【PyTorch】14 AI Artist: Neural Network Style Transfer

For more details, see this CSDN post.

1. Dataset

The COCO dataset is used (official website, download link): 13.5 GB in total, 82,783 images.
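
Since the training script later in this post loads the data with torchvision's ImageFolder, the images have to sit in at least one sub-directory under the dataset root; the class label it assigns is simply ignored. A minimal sketch of the expected layout, with a hypothetical path:

import torchvision as tv

# ImageFolder treats every sub-directory as a class, so all COCO images go into one
# dummy sub-folder, e.g. <data_root>/coco/*.jpg (hypothetical layout)
dataset = tv.datasets.ImageFolder('/path/to/data_root')
img, _ = dataset[0]   # the dummy class label is discarded during training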

2. Principle Overview

Style transfer involves two kinds of images: style images (a Picasso, a Van Gogh, ...) and content images, which usually come from the real world.

This post focuses on Fast Neural Style; for the original Neural Style, see the official PyTorch tutorial.

A convincing style-transfer result has to meet two requirements: the generated image should match the content image in content and detail, and it should match the style image in style as closely as possible. Accordingly, two losses are defined: a content loss (a common choice is a per-pixel difference, known as pixelwise loss; a perceptual loss instead measures the difference at a higher semantic level, on feature maps) and a style loss.
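
As a rough illustration of the two content-loss variants, the sketch below compares a pixelwise loss with a perceptual loss computed on VGG features; the tensors are random placeholders and ImageNet normalization is omitted for brevity:

import torch
from torch.nn import functional as F
from torchvision.models import vgg16

generated = torch.rand(1, 3, 256, 256)   # output of the style-transfer network (placeholder)
content = torch.rand(1, 3, 256, 256)     # original content image (placeholder)

# pixelwise loss: compare raw pixel values
pixel_loss = F.mse_loss(generated, content)

# perceptual loss: compare feature maps from a pretrained VGG (here up to relu2_2)
vgg_slice = vgg16(pretrained=True).features[:9].eval()
with torch.no_grad():
    target_feat = vgg_slice(content)
perceptual_loss = F.mse_loss(vgg_slice(generated), target_feat)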

The Gram matrix is commonly used to represent the style of an image. For an input image, a convolutional layer's output has shape C×H×W, where C is the number of channels (one often says there are C convolution kernels, each learning a different feature of the image). Each kernel's H×W output is one feature map of the image. For a C×H×W set of feature maps F, the Gram matrix has shape C×C, and its element $G_{i,j}$ is defined as:

$$G_{i, j}=\sum_{k} F_{i k} F_{j k}$$

where $F_{ik}$ is the k-th pixel of the i-th feature map. Three points about the Gram matrix are worth noting:

  • The Gram matrix is computed by summation, so spatial information is discarded: randomly shuffling the pixels of an image yields the same Gram matrix as the original, while texture, color, and other style information is preserved (verified numerically in the sketch after this list)
  • The Gram matrix is independent of the spatial size of the feature map; its shape depends only on the number of channels C
  • For a C×H×W feature map, reshape it into a C×(HW) matrix F and compute $F \cdot F^{T}$; the result is the Gram matrix
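
The shuffle-invariance mentioned in the first point is easy to check numerically. A minimal sketch on a random feature map (names are illustrative; the normalization matches gram_matrix in the full code below):

import torch

def gram(f):
    # f: a C x H x W feature map; reshape to C x (HW) and compute F @ F^T
    c, h, w = f.shape
    flat = f.view(c, h * w)
    return flat @ flat.t() / (c * h * w)

feat = torch.randn(8, 16, 16)                          # toy feature map, C=8, H=W=16
perm = torch.randperm(16 * 16)                         # random spatial permutation
shuffled = feat.view(8, -1)[:, perm].view(8, 16, 16)   # same pixels, scrambled positions
print(torch.allclose(gram(feat), gram(shuffled)))      # True: spatial layout is discarded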

The network architecture of Fast Neural Style is shown below:

[Figure: Fast Neural Style network architecture]

3. Implementing Style Transfer with PyTorch

The style-transfer network follows the official PyTorch example; its structure is shown below:
[Figure: style-transfer network structure]
content_losses (content loss curve during training):
[Figure: content loss curve]
style_losses (style loss curve during training):
[Figure: style loss curve]

4. Results

Target style:

[Figure: style image]

Style transfer is applied to the following images:

Example 1:
[Figure: content image and stylized result]
Example 2:
[Figure: content image and stylized result]
Example 3:
[Figure: content image and stylized result]
Example 4:
[Figure: content image and stylized result]

5. Full Code

import torch
from torch import nn
from torchvision.models import vgg16
from collections import namedtuple
import numpy as np
import torchvision as tv
from torch.utils import data
import tqdm
from torch.nn import functional as F
import matplotlib.pyplot as plt


class Vgg16(nn.Module):
    def __init__(self):
        super(Vgg16, self).__init__()
        features = list(vgg16(pretrained=True).features)[:23]
        self.features = nn.ModuleList(features).eval()

    def forward(self, x):
        results = []
        # layers 3, 8, 15, and 22 of vgg16.features are relu1_2, relu2_2, relu3_3, relu4_3
        for ii, model in enumerate(self.features):
            x = model(x)
            if ii in {3, 8, 15, 22}:
                results.append(x)

        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        return vgg_outputs(*results)
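
# Illustrative usage of Vgg16 (added note, not part of the original listing):
#   feats = Vgg16()(torch.randn(1, 3, 256, 256))
#   feats.relu2_2.shape   # -> torch.Size([1, 128, 128, 128])
# The namedtuple return value lets the training loop pick individual activations by name.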


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = int(np.floor(kernel_size / 2))
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out


class UpsampleConvLayer(nn.Module):
    """UpsampleConvLayer
    instead of ConvTranspose2d, we do UpSample + Conv2d
    see ref for why.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        reflection_padding = int(np.floor(kernel_size / 2))
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        x_in = x
        if self.upsample:
            x_in = torch.nn.functional.interpolate(x_in, scale_factor=self.upsample)
        out = self.reflection_pad(x_in)
        out = self.conv2d(out)
        return out

class ResidualBlock(nn.Module):
    """ResidualBlock
    introduced in: https://arxiv.org/abs/1512.03385
    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out = out + residual
        return out


class TransformerNet(nn.Module):
    def __init__(self):
        super(TransformerNet, self).__init__()

        # Down sample layers
        self.initial_layers = nn.Sequential(
            ConvLayer(3, 32, kernel_size=9, stride=1),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(True),
            ConvLayer(32, 64, kernel_size=3, stride=2),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(True),
            ConvLayer(64, 128, kernel_size=3, stride=2),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(True),
        )

        # Residual layers
        self.res_layers = nn.Sequential(
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128)
        )

        # Upsampling Layers
        self.upsample_layers = nn.Sequential(
            UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(True),
            UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(True),
            ConvLayer(32, 3, kernel_size=9, stride=1)
        )

    def forward(self, x):
        x = self.initial_layers(x)
        x = self.res_layers(x)
        x = self.upsample_layers(x)
        return x
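
# Shape flow through TransformerNet for a 3 x 256 x 256 input (added note, illustrative):
#   initial_layers : 3x256x256 -> 32x256x256 -> 64x128x128 -> 128x64x64
#   res_layers     : 128x64x64 (spatial size unchanged)
#   upsample_layers: 128x64x64 -> 64x128x128 -> 32x256x256 -> 3x256x256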


IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def gram_matrix(y):
    """
    Input shape: b,c,h,w
    Output shape: b,c,c
    """
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

def get_style_data(path):
    """
    load style image,
    Return: tensor shape 1*c*h*w, normalized
    """
    style_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    style_image = tv.datasets.folder.default_loader(path)
    style_tensor = style_transform(style_image)
    return style_tensor.unsqueeze(0)


def normalize_batch(batch):
    """
    Input: b,ch,h,w  0~255
    Output: b,ch,h,w  -2~2
    """
    mean = batch.data.new(IMAGENET_MEAN).view(1, -1, 1, 1)
    std = batch.data.new(IMAGENET_STD).view(1, -1, 1, 1)
    mean = (mean.expand_as(batch.data))
    std = (std.expand_as(batch.data))
    return (batch / 255.0 - mean) / std
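
# Added note: the transformer net consumes and produces images in the 0~255 range, while the
# pretrained VGG expects ImageNet-normalized inputs, so normalize_batch is applied to both the
# content batch x and the generated batch y before they are passed to Vgg16 in the training loop.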


class Config(object):
    # General Args
    # use_gpu = True
    # model_path = None  # pretrain model path (for resume training or test)

    # Train Args
    image_size = 256  # image crop_size for training
    batch_size = 8
    data_root = '/mnt/Data1/ysc/18/coco'  # dataset root:$data_root/coco/a.jpg
    num_workers = 4  # dataloader num of workers

    lr = 1e-3
    epoches = 2  # total epoch to train
    content_weight = 1e5  # weight of content_loss
    style_weight = 1e10  # weight of style_loss

    style_path = '/mnt/Data1/ysc/18/style.jpg'  # style image path
    # env = 'neural-style'  # visdom env
    plot_every = 100  # save an intermediate result every 100 batches

    # debug_file = '/tmp/debugnn'  # touch $debug_fie to interrupt and enter ipdb

    # Test Args
    content_path = '/mnt/Data1/ysc/18/COCO_train2014_000000000009.jpg'  # input file to do style transfer [for test]
    result_path = '/mnt/Data1/ysc/18/'  # style transfer result [for test]


def show(num):
    opt = Config()
    # input image preprocess
    content_image = tv.datasets.folder.default_loader(opt.content_path)
    content_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device).detach()
    # style transfer and save output
    transformer.eval()
    with torch.no_grad():
        output = transformer(content_image)
        output_data = output.cpu().data[0]
        tv.utils.save_image(((output_data / 255)).clamp(min=0, max=1), opt.result_path + '{}.jpg'.format(num))
    transformer.train()


def test(path1, path2):
    content_image = tv.datasets.folder.default_loader(path1)
    content_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device).detach()
    # style transfer and save output
    transformer.eval()
    with torch.no_grad():
        output = transformer(content_image)
        output_data = output.cpu().data[0]
        tv.utils.save_image(((output_data / 255)).clamp(min=0, max=1), path2)
    transformer.train()


if __name__ == '__main__':
    opt = Config()

    device = torch.device('cuda')

    # Data loading
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size, num_workers=opt.num_workers)

    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    style = get_style_data(opt.style_path)       # 1*c*h*w
    # print(style.size())
    style = style.to(device)

    transformer = TransformerNet()
    transformer.to(device)
    # Optimizer
    optimizer = torch.optim.Adam(transformer.parameters(), opt.lr)

    # gram matrix for style image
    with torch.no_grad():
        features_style = vgg(style)
        gram_style = [gram_matrix(y) for y in features_style]

    content_losses = []
    style_losses = []

    # for epoch in range(opt.epoches):
    #     for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):
    #         # Train
    #         optimizer.zero_grad()
    #         x = x.to(device)
    #         y = transformer(x)
    #         y = normalize_batch(y)
    #         x = normalize_batch(x)
    #         features_y = vgg(y)
    #         features_x = vgg(x)
    #
    #         # content loss
    #         content_loss = opt.content_weight * F.mse_loss(features_y.relu2_2, features_x.relu2_2)
    #
    #         # style loss
    #         style_loss = 0.
    #         for ft_y, gm_s in zip(features_y, gram_style):
    #             gram_y = gram_matrix(ft_y)
    #             style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
    #         style_loss *= opt.style_weight
    #
    #         content_losses.append(content_loss.item())
    #         style_losses.append(style_loss.item())
    #
    #         total_loss = content_loss + style_loss
    #         total_loss.backward()
    #         optimizer.step()
    #
    #         if (ii + 1) % opt.plot_every == 0:
    #             show((ii + 1) // opt.plot_every)
    #
    # torch.save(transformer.state_dict(),
    #            '/mnt/Data1/ysc/18/model_best.pth')
    #
    # plt.figure(1)
    # plt.plot(content_losses)
    # plt.figure(2)
    # plt.plot(style_losses)
    # plt.show()

    transformer.load_state_dict(torch.load('/mnt/Data1/ysc/18/model_best.pth'))
    A = '/mnt/Data1/ysc/18/微信图片_20210611111932.jpg'
    B = '/mnt/Data1/ysc/18/test3.jpg'
    test(A, B)

Summary

I reproduced the code from the original book. The dataset is large, and the results look quite good. The points that still need to be understood in detail:

  • why the Gram matrix, defined this way, can quantify the difference between the generated image and the target style
  • how the mean and standard deviation of a batch are computed and used for normalization
  • how feature extraction with the VGG network is implemented in code
  • how the style-transfer network is built and implemented in detail (upsampling, residual blocks, ...)