[Generative Networks] Getting Started (3): Style Transfer, Code and Results

Style Transfer Notes

This follows the classic paper xxx; I will not go over the theory here. Given a content image and a style image, the style of the style image can be transferred onto the content image.
One difference from the earlier code in this series is that the variable being optimized is the image itself; VGG is only used to extract features and its weights are not updated. Concretely, all three images (the content image, the style image, and the target image we want to generate) are passed through VGG, and features are taken from a few intermediate layers. From these features we compute a content loss, which keeps the content of the target image close to the content image, and a style loss, which pulls the style of the target image toward that of the style image.

To speed up the iterations, the target is usually initialized with the content image.
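
Written out, the losses that the code below computes are the following (with $F^l$ the VGG feature map of layer $l$, flattened to a $c_l \times h_l w_l$ matrix; Gatys et al. use sums and a slightly different normalization, but the mean-based version is what this particular code implements):

$$L_{\text{content}} = \sum_{l} \operatorname{mean}\big((F^l_{\text{target}} - F^l_{\text{content}})^2\big)$$

$$G^l = F^l (F^l)^{\top}, \qquad L_{\text{style}} = \sum_{l} \frac{\operatorname{mean}\big((G^l_{\text{target}} - G^l_{\text{style}})^2\big)}{c_l\, h_l\, w_l}$$

$$L = L_{\text{content}} + \text{style\_weight} \cdot L_{\text{style}}$$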

Without further ado, here is the code:

import os
os.chdir(os.path.dirname(__file__))
from torchvision import models
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

sample_dir = 'samples_style_transfer'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir, exist_ok=True)

writer = SummaryWriter(sample_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_image(image_path, transform=None, max_size=None, shape=None):
    image = Image.open(image_path)

    if max_size:
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(tuple(size.astype(int)), Image.LANCZOS)
    
    if shape:
        image = image.resize(shape, Image.LANCZOS)

    if transform:
        image = transform(image).unsqueeze(0)

    return image.to(device)

class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28']  # conv1_1, conv2_1, conv3_1, conv4_1, conv5_1 of VGG-19
        self.vgg = models.vgg19(pretrained=True).features
    
    def forward(self, x):
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features



def main(config):
    T = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406), 
            std=(0.229, 0.224, 0.225))
    ])

    content = load_image(config.content, T, max_size=config.max_size)
    style = load_image(config.style, T, shape=[content.size(3), content.size(2)])  # PIL expects (width, height)

    target = content.clone().requires_grad_(True)

    optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
    vgg = VGGNet().to(device).eval()

    for epoch in range(config.total_step):
        target_feature = vgg(target)
        content_feature = vgg(content)
        style_feature = vgg(style)

        style_loss = 0
        content_loss = 0
        for f1, f2, f3 in zip(target_feature, content_feature, style_feature):
            content_loss += torch.mean((f1-f2)**2)

            _,c,h,w = f1.size()
            f1 = f1.view(c, h*w)
            f3 = f3.view(c, h*w)

            # gram matrix
            f1 = torch.mm(f1, f1.t())
            f3 = torch.mm(f3, f3.t())

            # style loss
            style_loss += torch.mean((f1-f3)**2) / (c*h*w)

        loss = content_loss + config.style_weight * style_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar('loss', loss.item(), global_step=epoch)
        writer.add_scalar('content_loss', content_loss.item(), global_step=epoch)
        writer.add_scalar('style_loss', style_loss.item(), global_step=epoch)

        if (epoch+1) % config.log_step == 0:
            print('Epoch [{}/{}], Loss: {:.4f}, Content loss: {:.4f}, Style loss: {:.4f}'.\
                format(epoch+1, config.total_step, loss.item(), content_loss.item(), style_loss.item()))

        if (epoch+1) % config.sample_step == 0:
            # undo the ImageNet normalization: new_mean = -mean/std, new_std = 1/std
            denorm = transforms.Normalize(mean=(-2.12, -2.04, -1.80), std=(4.37, 4.46, 4.44))
            img = target.clone().squeeze()
            img = denorm(img).clamp_(0, 1)
            writer.add_image('img', img, global_step=epoch)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--content', type=str, default='data/content.png')
    parser.add_argument('--style', type=str, default='data/style.png')
    parser.add_argument('--max_size', type=int, default=400)
    parser.add_argument('--total_step', type=int, default=2000)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=500)
    parser.add_argument('--style_weight', type=float, default=100)
    parser.add_argument('--lr', type=float, default=0.003)

    config = parser.parse_args()
    config.total_step = 20000
    config.sample_step = 100

    print(config)
    main(config)
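
The script above only writes its intermediate results to TensorBoard (viewable with tensorboard --logdir samples_style_transfer). If you also want the final image on disk, a minimal addition at the end of main() could look like the sketch below; result.png is just an example file name.

    # optional: save the final stylized image to disk (append at the end of main())
    from torchvision.utils import save_image
    denorm = transforms.Normalize(mean=(-2.12, -2.04, -1.80), std=(4.37, 4.46, 4.44))
    final_img = denorm(target.detach().clone().squeeze()).clamp_(0, 1)
    save_image(final_img, os.path.join(sample_dir, 'result.png'))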
  • Content image
    [image]

  • Style image
    [image]

  • Around 4,000 iterations
    [image]

  • After 20,000 iterations, the style is getting noticeably closer.
    [image]
    But this kind of style transfer has a drawback: it has to keep iterating on a single image. Can we instead take a new image and get the stylized result almost immediately after feeding it in?

Of course we can. That is exactly what fast neural style transfer does.

Fast Style Transfer Notes

For the underlying idea, see https://blog.csdn.net/qq_33590958/article/details/96122789
In essence, a dedicated transform network is designed to perform the style conversion, while the pretrained VGG network is kept to extract features for the loss; the overall loss still follows the original paper. Once this transform network is trained, a new image only needs a single forward pass to produce the stylized output, which is very fast (see the inference sketch after the code below). One drawback is that every new style image requires retraining the network (at least that is my understanding; corrections are welcome).
[image]
The transform network itself is shown below. It is a fully convolutional network with roughly three stages: downsampling -> residual blocks -> upsampling.
[image: transform network architecture]

The code below is adapted from the official PyTorch examples:

import os
# os.chdir(os.path.dirname(__file__))
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
import argparse

sample_dir = 'samples_fast_style_transfer'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir, exist_ok=True)

writer = SummaryWriter(sample_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(0)
torch.manual_seed(0)

def load_image(filename, size=None, scale=None):
    img = Image.open(filename).convert('RGB')
    if size is not None:
        img = img.resize((size, size), Image.LANCZOS)
    elif scale is not None:
        size = (int(img.size[0] / scale), int(img.size[1] / scale))
        img = img.resize(size, Image.LANCZOS)
    return img

def save_image(filename, data):
    img = data.clone().clamp(0, 255).numpy()
    img = img.transpose(1,2,0).astype('uint8')
    img = Image.fromarray(img)
    img.save(filename)

def gram_matrix(y):
    b,ch,h,w = y.size()
    features = y.view(b,ch,h*w)
    features_t = features.transpose(1,2)
    gram = features.bmm(features_t)/(ch*h*w)
    return gram

def normalize_batch(batch):
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div_(255.0)
    return (batch - mean) / std

def save_checkpoint(model, epochs):
    model.eval().cpu()
    save_model_filename = "epoch_" + str(epochs) + ".model"
    save_model_path = os.path.join(sample_dir, save_model_filename)
    torch.save(model.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)

class ConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

class ResidualBlock(torch.nn.Module):
    """ResidualBlock
    introduced in: https://arxiv.org/abs/1512.03385
    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out = out + residual
        return out


class UpsampleConvLayer(torch.nn.Module):
    """UpsampleConvLayer
    Upsamples the input and then does a convolution. This method gives better results
    compared to ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) # note this padding scheme; see https://blog.csdn.net/LionZYT/article/details/120181586
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        x_in = x
        if self.upsample:
            x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
        out = self.reflection_pad(x_in)
        out = self.conv2d(out)
        return out

class TransformerNet(torch.nn.Module):
    def __init__(self):
        super(TransformerNet, self).__init__()
        # Initial convolution layers
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
        # Residual layers
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        # Upsampling Layers
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
        # Non-linearities
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        y = self.relu(self.in1(self.conv1(X)))
        y = self.relu(self.in2(self.conv2(y)))
        y = self.relu(self.in3(self.conv3(y)))
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.res4(y)
        y = self.res5(y)
        y = self.relu(self.in4(self.deconv1(y)))
        y = self.relu(self.in5(self.deconv2(y)))
        y = self.deconv3(y)
        return y
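
# Shape note: for a 1x3x256x256 input, the spatial size goes 256 -> 128 -> 64 through the two
# stride-2 convolutions, stays at 64 through the five residual blocks, and comes back to 256
# through the two x2 upsampling layers, so the output has the same 1x3x256x256 shape as the input.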

from collections import namedtuple
class Vgg16(nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad_(False)

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
        return out

def train():
    epochs = 10
    batch_size = 4
    image_size = 256
    learning_rate = 1e-3
    style_weight = 1e6
    content_weight = 1e1
    log_step = 10

    dataset_path = 'data/fast_neural_style/dataset'
    style_image = 'data/fast_neural_style/style-images/mosaic.jpg'
    content_image = 'data/fast_neural_style/content-images/amber.jpg'
    
    T = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x : x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(dataset_path, T)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, shuffle=True)

    transformer = TransformerNet().to(device)
    optimizer = torch.optim.Adam(transformer.parameters(), lr=learning_rate)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)

    # load style image
    style_T = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x :x.mul(255))
    ])
    style = load_image(style_image, size = image_size)
    style = style_T(style)
    style = style.repeat(batch_size, 1,1,1).to(device)

    features_style = vgg(normalize_batch(style))
    gram_style = [gram_matrix(y) for y in features_style]

    cnt = 0

    # load content image
    content_image = load_image(content_image)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    for epoch in range(epochs):
        transformer.train()
        for batchid, (x, _) in enumerate(train_loader):

            x = x.to(device)
            y = transformer(x)

            # normalize so the inputs match the mean/std distribution the pretrained VGG expects
            x = normalize_batch(x)
            y = normalize_batch(y.mul(255))

            features_x = vgg(x)
            features_y = vgg(y)

            content_loss = mse_loss(features_y.relu2_2, features_x.relu2_2) * content_weight

            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s)

            style_loss = style_loss * style_weight

            total_loss = content_loss  + style_loss 
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            cnt += 1
            if cnt % log_step == 0:
                print('Epoch [{}/{}], Step [{}], Loss: {:.4f}, Content loss: {:.4f}, Style loss: {:.4f}'.\
                    format(epoch, epochs, cnt, total_loss.item(), content_loss.item(), style_loss.item()))

                writer.add_scalar('loss', total_loss.item(), global_step=cnt)
                writer.add_scalar('content_loss', content_loss.item(), global_step=cnt)
                writer.add_scalar('style_loss', style_loss.item(), global_step=cnt)

            if cnt % 100 == 0:
                img = evaluate(content_image, transformer)
                writer.add_image('target_images', img, global_step=cnt, dataformats='CHW')

    save_checkpoint(transformer, epoch)

def evaluate(content_image, transformer):
    # run one stylization pass (no gradients needed) and map the output back to [0, 1]
    transformer.eval()
    with torch.no_grad():
        output_image = transformer(content_image).cpu()
    transformer.train()
    mean = output_image.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = output_image.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

    output_image = output_image * std + mean
    return output_image[0].clamp(0, 1)

if __name__ == '__main__':
    train()
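
Once training finishes, stylizing a new image is a single forward pass through the saved TransformerNet, which is where the speedup comes from. Below is a minimal inference sketch I added for completeness; the checkpoint name epoch_9.model matches what save_checkpoint() writes for this 10-epoch run, and the paths are only examples.

def stylize(checkpoint_path, content_path, output_path):
    # load the trained transform network and stylize one image in a single forward pass
    from torchvision.utils import save_image
    model = TransformerNet()
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device).eval()

    T = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    x = T(load_image(content_path)).unsqueeze(0).to(device)

    with torch.no_grad():
        out = model(x).cpu()

    # same de-normalization convention as evaluate() above
    mean = out.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = out.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    save_image((out * std + mean)[0].clamp(0, 1), output_path)

# e.g. stylize(os.path.join(sample_dir, 'epoch_9.model'),
#              'data/fast_neural_style/content-images/amber.jpg', 'stylized.png')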

The block below deserves a closer look. Instead of upsampling with ConvTranspose2d, it first upsamples via nearest-neighbor interpolation, then applies reflection padding followed by an ordinary convolution. According to the comments in the code, this gives better results than ConvTranspose2d (the distill.pub reference in the docstring discusses exactly these checkerboard artifacts).

class UpsampleConvLayer(torch.nn.Module):
    """UpsampleConvLayer
    Upsamples the input and then does a convolution. This method gives better results
    compared to ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) # note this padding scheme; see https://blog.csdn.net/LionZYT/article/details/120181586
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        x_in = x
        if self.upsample:
            x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
        out = self.reflection_pad(x_in)
        out = self.conv2d(out)
        return out

For ReflectionPad2d, see https://blog.csdn.net/LionZYT/article/details/120181586
What it does: it pads the input on all four sides by mirroring the image about its outermost pixels.

[image]
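
A tiny example makes the behaviour concrete (a quick standalone check, not part of the training script):

import torch

x = torch.arange(9.).view(1, 1, 3, 3)   # rows: [0 1 2], [3 4 5], [6 7 8]
pad = torch.nn.ReflectionPad2d(1)
print(pad(x))
# tensor([[[[4., 3., 4., 5., 4.],
#           [1., 0., 1., 2., 1.],
#           [4., 3., 4., 5., 4.],
#           [7., 6., 7., 8., 7.],
#           [4., 3., 4., 5., 4.]]]])
# note that the border pixels themselves are not repeated; the image is mirrored about them
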
The fast style transfer results are as follows:

  • Content image (original)
    [image]
  • Style image
    [image]
  • Final stylized result

Personally I don't think the style match is all that close, but it is getting there, and the output also becomes clearer and clearer over the course of training.
[image]
For comparison, here is the final output shipped with the official example:

  • The officially trained result, which looks much better
    [image]

In this code the style_weight and content_weight are set extremely high, and I am not sure why. I adjusted them slightly, which may have affected the final training result; feel free to tune them further yourself.
