Network Visualization and Style Transfer

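This article shows how to use deep learning for network visualization and style transfer. Starting from a pretrained SqueezeNet model, it computes a content loss, a style loss and a total-variation loss to transform an original image toward a target style. Topics include loading ImageNet data, computing gradients, generating fooling images, feature inversion and texture synthesis. It also discusses key functions such as PyTorch's gather method and Image.fromarray, and how the loss functions are computed.
(Summary generated automatically by CSDN.)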

Theory: CS231n notes: Neural Network Visualization (Part 1), iwill323's CSDN blog

CS231n 2022 lecture slide notes: Neural Network Visualization (Part 2), Neural Style Transfer, iwill323's CSDN blog

Contents

Preparation
Imports
Pretrained Model
Loading ImageNet Validation Images
Image processing functions
Saliency Maps
Fooling Images
Class Visualization
Style Transfer
Method
Pretrained model
Computing Loss
Content loss
Style loss
Total-variation regularization
Model
Feature extraction
Main function
Initialization function
Generating images
Feature Inversion
Texture synthesis
Functions worth noting
PyTorch gather method
Image.fromarray
max
MSE


We will start from a CNN model which has been pretrained to perform image classification on the ImageNet dataset. We will use this model to define a loss function which quantifies our current unhappiness with our image. Then we will use backpropagation to compute the gradient of this loss with respect to the pixels of the image. We will then keep the model fixed and perform gradient descent on the image to synthesize a new image which minimizes the loss.
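The recipe above (freeze the network, backpropagate a loss to the pixels, then run gradient descent on the image) can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the assignment's code: the tiny convolutional model stands in for SqueezeNet, and the loss (negative score of a hypothetical class 3), plain SGD and the learning rate of 1.0 are all choices made for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained CNN (not SqueezeNet): any frozen
# nn.Module mapping (N, 3, H, W) images to (N, num_classes) scores is used the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()
for p in model.parameters():
    p.requires_grad = False  # the model stays fixed

img = torch.randn(1, 3, 32, 32, requires_grad=True)  # the image is the variable
target_class = 3  # hypothetical class index
optimizer = torch.optim.SGD([img], lr=1.0)


def loss_fn(x):
    # Example "unhappiness": the negative unnormalized score of the target class
    return -model(x)[0, target_class]


initial_loss = loss_fn(img).item()
for _ in range(50):
    optimizer.zero_grad()
    loss = loss_fn(img)
    loss.backward()   # gradient of the loss w.r.t. the image pixels
    optimizer.step()  # one gradient-descent step on the image
final_loss = loss_fn(img).item()
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Only the image is passed to the optimizer, so the weights never change; because their `requires_grad` is False, no gradients are even stored for them.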

We will explore three techniques for image generation.

Preparation

Imports

# Setup cell.
import torch
import torchvision
import torch.nn as nn
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms


%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # Set default size of plots.
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

SQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # "cuda:1" would fail on a single-GPU machine

%load_ext autoreload
%autoreload 2

Pretrained Model

For the purposes of this assignment we will use SqueezeNet [1], which achieves accuracies comparable to AlexNet but with a significantly reduced parameter count and computational complexity. Using SqueezeNet rather than AlexNet or VGG or ResNet means that we can easily perform all image generation experiments on CPU.

[1] Iandola et al, "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and < 0.5MB model size", arXiv 2016

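# Download and load the pretrained SqueezeNet model.
# (torchvision >= 0.13 replaces the deprecated pretrained=True with the weights argument.)
model = torchvision.models.squeezenet1_1(weights=torchvision.models.SqueezeNet1_1_Weights.IMAGENET1K_V1)
model.eval()  # the classifier contains dropout; eval mode makes the scores deterministic

# We don't want to train the model, so tell PyTorch not to compute gradients
# with respect to model parameters.
for param in model.parameters():
    param.requires_grad = False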

Loading ImageNet Validation Images

# http://cs231n.stanford.edu/imagenet_val_25.npz
from cs231n.data_utils import load_imagenet_val
X, y, class_names = load_imagenet_val(num=5)

plt.figure(figsize=(12, 6))
for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.imshow(X[i])
    plt.title(class_names[y[i]])
    plt.axis('off')
plt.gcf().tight_layout()

Helper function used above: load_imagenet_val (from cs231n/data_utils.py)

import os

import numpy as np


def load_imagenet_val(num=None):
    """Load a handful of validation images from ImageNet.

    Inputs:
    - num: Number of images to load (max of 25)

    Returns:
    - X: numpy array with shape [num, 224, 224, 3]
    - y: numpy array of integer image labels, shape [num]
    - class_names: dict mapping integer label to class name
    """
    imagenet_fn = os.path.join(
        os.path.dirname(__file__), "datasets/imagenet_val_25.npz"
    )
    if not os.path.isfile(imagenet_fn):
        print("file %s not found" % imagenet_fn)
        print("Run the following:")
        print("cd cs231n/datasets")
        print("bash get_imagenet_val.sh")
        assert False, "Need to download imagenet_val_25.npz"

    # label_map is stored as a pickled dict (a 0-d object array), so the file must be
    # loaded with allow_pickle=True; the default became False in NumPy 1.16.3.
    # https://stackoverflow.com/questions/55890813/how-to-fix-object-arrays-cannot-be-loaded-when-allow-pickle-false-for-imdb-loa
    f = np.load(imagenet_fn, allow_pickle=True)
    X = f["X"]
    y = f["y"]
    class_names = f["label_map"].item()  # unwrap the 0-d object array into a dict
    if num is not None:
        X = X[:num]
        y = y[:num]
    return X, y, class_names
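Why `allow_pickle=True` and `.item()` are needed can be seen with a tiny in-memory `.npz` file. This is an illustrative sketch: the arrays and the `label_map` entries are made up and only mimic the layout of imagenet_val_25.npz.

```python
import io

import numpy as np

# A tiny in-memory stand-in for imagenet_val_25.npz; the labels are made up.
label_map = {0: "class_a", 1: "class_b"}
buf = io.BytesIO()
np.savez(buf, X=np.zeros((2, 4, 4, 3), dtype=np.uint8),
         y=np.array([0, 1]), label_map=label_map)  # the dict is saved as a pickled 0-d object array

buf.seek(0)
f = np.load(buf)  # allow_pickle defaults to False
try:
    f["label_map"]
    refused = False
except ValueError:
    refused = True  # object arrays cannot be loaded when allow_pickle=False

buf.seek(0)
f = np.load(buf, allow_pickle=True)
restored = f["label_map"].item()  # .item() unwraps the 0-d array into the original dict
print(refused, restored)
```

Plain numeric arrays such as `X` and `y` load fine without pickling; only the object array holding the dict needs it.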

Checking the properties of X and y:

>>> print(X.shape)  # 5 images
(5, 224, 224, 3)
>>> print(y)
[958 85 244 182 294]
>>> for x in X:
...     print(x.shape)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)

Image processing functions

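def preprocess(img, size=224):
    """Convert a PIL image into a normalized (1, 3, H, W) tensor for SqueezeNet."""
    transform = transforms.Compose([
        transforms.Resize(size),  # an int resizes the shorter side to `size`, keeping the aspect ratio
        transforms.ToTensor(),    # (H, W, C) uint8 in [0, 255] -> (C, H, W) float in [0, 1]
        transforms.Normalize(mean=SQUEEZENET_MEAN.tolist(),
                    std=SQUEEZENET_STD.tolist()),
        # Add a batch dimension in the first position of the tensor:
        # a tensor of shape (C, H, W) becomes (1, C, H, W).
        transforms.Lambda(lambda x: x[None]),
    ])
    # return transform(img).unsqueeze(0)  # also adds the leading dimension (use it without the Lambda above)
    return transform(img)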


def deprocess(img, should_rescale=True):
    """
    De-processes a PyTorch tensor from the output of the CNN model
    to become a PIL image.
    """
    transform = transforms.Compose([
        transforms.Lambda(lambda x: x[0]),  # Remove the batch dimension: (1, C, H, W) -> (C, H, W).
        transforms.Normalize(mean=[0, 0, 0], std=(1.0 / SQUEEZENET_STD).tolist()),  # Undo the std division: x * std
        transforms.Normalize(mean=(-SQUEEZENET_MEAN).tolist(), std=[1, 1, 1]),  # Undo the mean subtraction: x + mean
        # Rescale all the values in the tensor so that they lie in the interval [0, 1] to prepare for transforming it into image pixel values.
        transforms.Lambda(rescale) if should_rescale else transforms.Lambda(lambda x: x),
        transforms.ToPILImage(),
    ])
    return transform(img)


def rescale(x):
    """ A function used internally inside `deprocess`.
        Rescale elements of x linearly to be in the interval [0, 1]
        with the minimum element(s) mapped to 0, and the maximum element(s)
        mapped to 1.
    """
    low, high = x.min(), x.max()
    x_rescaled = (x - low) / (high - low)
    return x_rescaled

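Alternatively, the following deprocess function can be used. Its input img is a PyTorch tensor that still has the batch dimension, while the standard deviation and mean are NumPy arrays of shape (3,). Broadcasting aligns trailing dimensions, so the channel axis of img is moved to the last position before the arithmetic and moved back to the front afterwards (ToPILImage expects a (C, H, W) tensor). Unlike the version above, this one clamps values to [0, 1] instead of rescaling them.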

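def deprocess(img):
    img = img[0].detach().cpu()  # drop the batch dimension: (1, C, H, W) -> (C, H, W)
    std = torch.as_tensor(SQUEEZENET_STD)
    mean = torch.as_tensor(SQUEEZENET_MEAN)
    # (C, H, W) -> (H, W, C) so the per-channel (3,) std/mean broadcast over the last axis
    img = torch.clamp(img.permute(1, 2, 0) * std + mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))  # back to (C, H, W)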
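Moving the channel axis is one way to make the per-channel constants broadcast. An equivalent that keeps channels first reshapes mean and std to (3, 1, 1) instead. The sketch below, which assumes a random normalized (1, 3, 16, 16) tensor, checks that both routes give the same result:

```python
import numpy as np
import torch

SQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

img = torch.randn(1, 3, 16, 16)  # a normalized batch of one image
mean = torch.as_tensor(SQUEEZENET_MEAN)
std = torch.as_tensor(SQUEEZENET_STD)

# Channel-last route: (H, W, 3) broadcasts against (3,)
via_permute = (img[0].permute(1, 2, 0) * std + mean).permute(2, 0, 1)

# Channel-first route: (3, 1, 1) broadcasts over H and W without any permute
via_view = img[0] * std.view(3, 1, 1) + mean.view(3, 1, 1)

print(torch.allclose(via_permute, via_view))
```

Which one to use is mostly taste; the reshape avoids the two permutes, while the permute version mirrors how NumPy images are usually laid out.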

Saliency Maps

We can use saliency maps to tell which part of the image influenced the classification decision made by the network.

A saliency map tells us the degree to which each pixel in the image affects the classification score for that image. To compute it, we compute the gradient of the unnormalized score corresponding to the correct class (which is a scalar) with respect to the pixels of the image. If the image has shape (3, H, W) then this gradient will also have shape (3, H, W); for each pixel in the image, this gradient tells us the amount by which the classification score will change if the pixel changes by a small amount. To compute the saliency map, we take the absolute value of this gradient, then take the maximum value over the 3 input channels; the final saliency map thus has shape (H, W) and all entries are nonnegative.
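The computation described above can be sketched directly with autograd, using gather to pick each image's correct-class score. This is an illustrative sketch rather than the author's implementation: `saliency_maps_sketch` is a hypothetical name, and the usage runs on a small stand-in model instead of SqueezeNet.

```python
import torch
import torch.nn as nn


def saliency_maps_sketch(X, y, model):
    """Saliency maps: |d score_y / d X|, maxed over the 3 input channels.

    X: (N, 3, H, W) float tensor; y: (N,) LongTensor of labels.
    Returns an (N, H, W) tensor of nonnegative values.
    """
    model.eval()  # disable dropout so the scores are deterministic
    X = X.clone().requires_grad_(True)
    scores = model(X)                                      # (N, num_classes)
    correct = scores.gather(1, y.view(-1, 1)).squeeze(1)   # (N,) correct-class scores
    # Each image only affects its own score, so the gradient of the sum
    # w.r.t. X[i] equals d correct[i] / d X[i].
    correct.sum().backward()
    saliency, _ = X.grad.abs().max(dim=1)                  # max over channels -> (N, H, W)
    return saliency


# Usage with a small hypothetical stand-in model (not SqueezeNet)
torch.manual_seed(0)
toy = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5))
for p in toy.parameters():
    p.requires_grad = False
X = torch.randn(2, 3, 16, 16)
y = torch.tensor([1, 4])
sal = saliency_maps_sketch(X, y, toy)
print(tuple(sal.shape))
```

Summing the per-image scores before calling backward gives all N gradients in one pass; this relies on the model having no cross-sample interaction, which holds for SqueezeNet in eval mode.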

import torch
import random
import torchvision.transforms as transforms
import numpy as np
import torch.nn as nn
from scipy.ndimage import gaussian_filter1d  # scipy.ndimage.filters is a deprecated alias

def compute_saliency_maps(X, y, model):
    """
    Compute a class saliency map using the model for images X and labels y.

    Input:
    - X: Input images; Tensor of shape (N, 3, H, W)
    - y: Labels for X; LongTensor of shape (N,)
    - model: A pretrained CNN that will be used to compute the saliency map.

    Returns:
    - saliency