Style Transfer Based on Transfer Learning

Style transfer means keeping the content of one photo unchanged while pushing its style toward that of another photo.
The training loss consists of two parts: content_loss, which measures how much target_pic's content differs from content_pic, and style_loss, which measures how much target_pic's style differs from style_pic. In this example, target_pic starts as a direct copy of content_pic, and during training its style is gradually pulled toward style_pic. A few layers of a pretrained vgg19 serve as the image feature extractor. The sketch below illustrates how the two loss terms combine.
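To make the loss structure concrete, here is a minimal runnable sketch with random toy tensors standing in for the real VGG feature maps (all tensors and shapes here are illustrative only; the real computation appears further down):

import torch

# toy stand-ins for VGG feature maps, already flattened to (channels, height*width)
target_feat  = torch.randn(64, 1024)
content_feat = torch.randn(64, 1024)
style_feat   = torch.randn(64, 1024)

content_loss = torch.mean((target_feat - content_feat) ** 2)     # content difference
style_loss   = torch.mean((target_feat @ target_feat.t()         # gram matrix of target
                           - style_feat @ style_feat.t()) ** 2)  # vs gram matrix of style
loss = content_loss + 100. * style_loss                          # style_weight = 100.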
Now let's look at the content_pic and style_pic used in this example:
(Figures: content_pic and style_pic)

from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_image(image_path, transform=None, max_size=None, shape=None):
    image = Image.open(image_path)
    if max_size: # if a max_size is given, scale the image so its longer side equals max_size
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(tuple(size.astype(int)), Image.ANTIALIAS) # ANTIALIAS is the same resampling filter as LANCZOS below

    if shape:  # if a shape is given, resize the image to exactly that (width, height)
        image = image.resize(shape, Image.LANCZOS)

    if transform: # if a transform is given, apply it to the image and then
                  # prepend a batch dimension, giving a tensor of size (1, channels, height, width)
        image = transform(image).unsqueeze(0)
        
    return image.to(device)


transform = transforms.Compose([
    transforms.ToTensor(),
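    # ImageNet channel means/stds; the pretrained vgg19 expects inputs normalized with these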
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
content = load_image("./风格迁移/content.jpg", transform, max_size=400)
# PIL's resize takes (width, height), so pass W (size(3)) before H (size(2))
style = load_image("./风格迁移/style.jpg", transform, shape=[content.size(3), content.size(2)])
unloader = transforms.ToPILImage()  # reconvert into PIL image

plt.ion()

def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)      # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:  # if a title was given, display it above the image
        plt.title(title)
    plt.pause(0.001) # pause a bit so that plots are updated


plt.figure()
imshow(content, title='Content Image')

(Figure: the content image)

vgg = models.vgg19(pretrained=True)
vgg
# The printout shows that vgg19 consists of three parts overall: features, avgpool and classifier.
# Following common practice, layers 0, 5, 10, 19 and 28 of the features part work well for extracting image features.
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace=True)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
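As an optional sanity check (a small sketch, not part of the original script), we can print what those five indices map to; they are the first convolution of each of vgg19's five blocks (conv1_1 through conv5_1 in the usual naming):

for idx in ['0', '5', '10', '19', '28']:
    print(idx, vgg.features._modules[idx])  # each prints a Conv2d layer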
class FeatureExtract_Net(nn.Module): # essentially just a few layers of the original vgg19
    def __init__(self):
        super(FeatureExtract_Net, self).__init__()
        self.vgg = models.vgg19(pretrained=True).features
        self.select = ['0', '5', '10', '19', '28']

    def forward(self, x):
        features = []
        for name, layer in self.vgg._modules.items():
            # self.vgg._modules is a dict mapping each layer's name to the layer itself;
            # items() yields the (name, layer) pairs one by one
            x = layer(x) # x passes through every layer of the vgg19 features, but only the outputs of layers 0, 5, 10, 19 and 28 are kept
            if name in self.select:
                features.append(x)
        return features
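Before wiring the extractor into training, here is a quick hypothetical check (the name extractor is ours, for illustration only) of what it returns: five feature maps whose channel count grows while the spatial size shrinks:

extractor = FeatureExtract_Net().to(device).eval()
with torch.no_grad():
    for i, feat in enumerate(extractor(content)):
        print(i, tuple(feat.shape))  # roughly (1, 64, H, W) down to (1, 512, H/16, W/16)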


target = content.clone().requires_grad_(True)  # target starts as an exact copy of the content photo; during training its style gradually drifts toward the style photo
optimizer = torch.optim.Adam([target], lr=0.003, betas=[0.5, 0.999])
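# Note: the optimizer's only parameter is the target image itself, so each
# gradient step updates the pixels of target rather than any network weights.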
vgg = FeatureExtract_Net().to(device).eval() # the vgg19 layers are used purely for feature extraction and are never trained, hence eval mode
target_features = vgg(target)

total_step = 2000  # number of optimization steps per image; for time's sake only 2000 were run
style_weight = 100. # weight of style_loss in the total loss
for step in range(total_step):
    if hasattr(torch.cuda, 'empty_cache'):
        torch.cuda.empty_cache()

    target_features = vgg(target)
    content_features = vgg(content)
    # since target is an exact copy of content, target_features and content_features are identical when training starts
    style_features = vgg(style)
    
    style_loss = 0
    content_loss = 0
    for f1, f2, f3 in zip(target_features, content_features, style_features):
        content_loss += torch.mean((f1-f2)**2)
        _, c, h, w = f1.size() # f1 has size (batch_size, channels, height, width); there is only one photo here, so batch_size is 1
        f1 = f1.view(c, h*w) # reshape target_features to (channels, height*width) so a gram matrix can be computed
        f3 = f3.view(c, h*w) # same reshape for style_features

        # Compute the gram matrices.
        # Following common practice, a gram matrix serves as a representation of a picture's style.
        f1_grammatrix = torch.mm(f1, f1.t())  # the style of the target picture
        f3_grammatrix = torch.mm(f3, f3.t())  # the style of the style picture
        style_loss += torch.mean((f1_grammatrix-f3_grammatrix)**2)/(c*h*w)  # how far apart the two styles are
        
    loss = content_loss + style_weight * style_loss
    
    # update target
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if step % 10 == 0:
        print("Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}"
             .format(step, total_step, content_loss.item(), style_loss.item()))  
        
# The training log below shows content_loss gradually increasing while style_loss gradually decreases.
# This is because target starts as an exact copy of content, so once training begins content_loss can only grow,
# while target's style moves ever closer to style as training proceeds, so style_loss keeps shrinking.
Step [0/2000], Content Loss: 0.0000, Style Loss: 1597.5527
Step [10/2000], Content Loss: 8.7127, Style Loss: 1224.9143
Step [20/2000], Content Loss: 16.2456, Style Loss: 921.6854
Step [30/2000], Content Loss: 19.5701, Style Loss: 749.3969
Step [40/2000], Content Loss: 21.8312, Style Loss: 637.0518
Step [50/2000], Content Loss: 23.5768, Style Loss: 557.3438
... ...
... ...
Step [1950/2000], Content Loss: 40.1516, Style Loss: 19.5787
Step [1960/2000], Content Loss: 40.1714, Style Loss: 19.4750
Step [1970/2000], Content Loss: 40.1896, Style Loss: 19.3724
Step [1980/2000], Content Loss: 40.2087, Style Loss: 19.2703
Step [1990/2000], Content Loss: 40.2266, Style Loss: 19.1693
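For reference, the inline gram-matrix computation in the loop above can be factored into a small helper. This is just a sketch under the same assumptions (a (1, c, h, w) feature map); it is not used by the code above:

def gram_matrix(feat):
    _, c, h, w = feat.size()
    f = feat.view(c, h * w)    # one row per channel
    return torch.mm(f, f.t())  # (c, c) matrix of channel-wise correlations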
denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44)) # the images were normalized during preprocessing; undo that here
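# These constants invert the earlier Normalize(mean, std): the inverse is
# Normalize(-mean/std, 1/std) per channel, e.g. -0.485/0.229 ≈ -2.12 and 1/0.229 ≈ 4.37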
img = target.clone().squeeze() # squeeze away the batch_size dimension
img = denorm(img).clamp_(0, 1) # after denorm some values fall outside [0, 1], so clamp them back into range
plt.figure()
imshow(img, title='Target Image') # see what the trained target_pic looks like

(Figure: the resulting target image)

a = transforms.ToPILImage()

a(img.cpu()).save('./final.jpg') # save target_pic
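As a side note, torchvision can save the tensor directly and skip the manual PIL conversion (the filename final_alt.jpg is just an example):

from torchvision.utils import save_image
save_image(img, './final_alt.jpg')  # accepts a (3, H, W) float tensor in [0, 1]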
