Style transfer keeps the content of one photo unchanged while making its style resemble that of another photo.
The training loss has two parts: content_loss, which measures how far target_pic has drifted from content_pic in content, and style_loss, which measures how far target_pic still is from style_pic in style. In this example we start by copying content_pic directly as target_pic, and during training we gradually push the style of target_pic toward style_pic. A few layers of a pretrained vgg19 serve as the image feature extractor.
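Note that no network weights are trained here: the pixels of target_pic themselves are the parameters being optimized. A rough, self-contained sketch of that idea (the toy loss and variable names are only illustrative, not the actual code below):

import torch
target = torch.rand(1, 3, 8, 8).requires_grad_(True)   # the image itself is the "parameter"
optimizer = torch.optim.Adam([target], lr=0.1)
for _ in range(5):
    loss = ((target - 0.5) ** 2).mean()   # toy loss; the real one is content_loss + style_weight * style_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                      # updates the pixels of target, nothing else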
First, let's look at the content_pic and style_pic used in this example:
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_image(image_path, transform=None, max_size=None, shape=None):
    image = Image.open(image_path)
    if max_size:  # if a maximum size is given, scale the image down proportionally
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(size.astype(int), Image.ANTIALIAS)  # ANTIALIAS and LANCZOS below are the same high-quality resampling filter
    if shape:  # if a target shape is given, resize the image to exactly that shape
        image = image.resize(shape, Image.LANCZOS)
    if transform:  # if a transform is given, apply it and add a batch dimension in front;
        # the size after the transform is (1, channels, height, width)
        image = transform(image).unsqueeze(0)
    return image.to(device)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics, which the pretrained VGG expects
                         std=[0.229, 0.224, 0.225])
])
content = load_image("./风格迁移/content.jpg", transform, max_size=400)
style = load_image("./风格迁移/style.jpg", transform, shape=[content.size(2), content.size(3)])
unloader = transforms.ToPILImage() # reconvert into PIL image
plt.ion()
def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # clone the tensor so we do not modify the original
    image = image.squeeze(0)      # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:  # show the title if one was given
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
plt.figure()
imshow(content, title='Image')
vgg = models.vgg19(pretrained=True)
vgg
# The printout shows that vgg19 consists of three parts: features, avgpool and classifier.
# Following prior work, layers 0, 5, 10, 19 and 28 of the features part (the first Conv2d of each of the
# five convolutional blocks, i.e. conv1_1, conv2_1, conv3_1, conv4_1 and conv5_1) work well for extracting
# image features; see the quick check after the printout below.
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace=True)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
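To verify that indices 0, 5, 10, 19 and 28 really are the first convolution of each block, a quick check (a sketch using the vgg loaded above) lists the Conv2d layers in vgg.features:

conv_idx = [name for name, layer in vgg.features.named_children() if isinstance(layer, nn.Conv2d)]
print(conv_idx)  # ['0', '2', '5', '7', '10', '12', '14', '16', '19', '21', '23', '25', '28', '30', '32', '34']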
class FeatureExtract_Net(nn.Module):  # essentially just a few layers of the original vgg
    def __init__(self):
        super(FeatureExtract_Net, self).__init__()
        self.vgg = models.vgg19(pretrained=True).features
        self.select = ['0', '5', '10', '19', '28']

    def forward(self, x):
        features = []
        for name, layer in self.vgg._modules.items():
            # self.vgg._modules is a dict mapping each layer's name to the layer itself;
            # items() yields the (name, layer) pairs one by one
            x = layer(x)  # x passes through every layer of vgg19.features, but only the outputs of layers 0, 5, 10, 19 and 28 are kept
            if name in self.select:
                features.append(x)
        return features
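A quick sanity check of the extractor (a throwaway sketch, not part of the pipeline; it assumes the content tensor and device defined above): it should return five feature maps whose channel counts match layers 0, 5, 10, 19 and 28 of the printout.

extractor = FeatureExtract_Net().to(device).eval()  # temporary instance used only for this check
with torch.no_grad():
    feats = extractor(content)
print([f.size(1) for f in feats])  # expected: [64, 128, 256, 512, 512]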
target = content.clone().requires_grad_(True)  # the target image starts as an exact copy of the content image; during training its style gradually shifts toward the style image
optimizer = torch.optim.Adam([target], lr=0.003, betas=[0.5, 0.999])
vgg = FeatureExtract_Net().to(device).eval()  # the vgg19 layers are only used for feature extraction and are never trained, so the model stays in eval mode
target_features = vgg(target)
total_step = 2000  # total number of optimization steps for this image; kept to 2000 for time reasons
style_weight = 100.  # weight of style_loss
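Before running the loop below, it may help to see the gram matrix computation in isolation. A minimal sketch on a random feature map (the shapes are arbitrary and chosen only for illustration):

f = torch.randn(1, 64, 50, 50)  # a fake feature map: (batch, channels, height, width)
_, c, h, w = f.size()
f = f.view(c, h * w)            # flatten each channel into one row
gram = torch.mm(f, f.t())       # (c, c): entry [i, j] correlates channel i with channel j over all spatial positions
print(gram.shape)               # torch.Size([64, 64])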
for step in range(total_step):
    if hasattr(torch.cuda, 'empty_cache'):
        torch.cuda.empty_cache()
    target_features = vgg(target)
    content_features = vgg(content)
    # since target starts as an exact copy of content, target_features and content_features are identical at the very first step
    style_features = vgg(style)
    style_loss = 0
    content_loss = 0
    for f1, f2, f3 in zip(target_features, content_features, style_features):
        content_loss += torch.mean((f1 - f2) ** 2)
        _, c, h, w = f1.size()  # f1 has size (batch_size, channels, height, width); there is only one image here, so batch_size is 1
        f1 = f1.view(c, h * w)  # reshape target_features to (channels, height*width) so the gram matrix can be computed
        f3 = f3.view(c, h * w)  # reshape style_features the same way, for the same reason
        # compute the gram matrices
        # following prior work, the gram matrix (the matrix of channel-to-channel correlations, summed over all spatial positions) is used to represent an image's style
        f1_grammatrix = torch.mm(f1, f1.t())  # style of the target picture
        f3_grammatrix = torch.mm(f3, f3.t())  # style of the style picture
        style_loss += torch.mean((f1_grammatrix - f3_grammatrix) ** 2) / (c * h * w)  # how far apart the two styles are
    loss = content_loss + style_weight * style_loss
    # update target
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 10 == 0:
        print("Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}"
              .format(step, total_step, content_loss.item(), style_loss.item()))
# The training log below shows content_loss gradually increasing while style_loss gradually decreases.
# This is expected: target starts out as an exact copy of content, so once training begins content_loss can only grow,
# while the style of target moves ever closer to style as training progresses, so style_loss keeps shrinking.
Step [0/2000], Content Loss: 0.0000, Style Loss: 1597.5527
Step [10/2000], Content Loss: 8.7127, Style Loss: 1224.9143
Step [20/2000], Content Loss: 16.2456, Style Loss: 921.6854
Step [30/2000], Content Loss: 19.5701, Style Loss: 749.3969
Step [40/2000], Content Loss: 21.8312, Style Loss: 637.0518
Step [50/2000], Content Loss: 23.5768, Style Loss: 557.3438
... ...
... ...
Step [1950/2000], Content Loss: 40.1516, Style Loss: 19.5787
Step [1960/2000], Content Loss: 40.1714, Style Loss: 19.4750
Step [1970/2000], Content Loss: 40.1896, Style Loss: 19.3724
Step [1980/2000], Content Loss: 40.2087, Style Loss: 19.2703
Step [1990/2000], Content Loss: 40.2266, Style Loss: 19.1693
denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))  # undo the normalization applied during preprocessing (new_mean = -mean/std, new_std = 1/std)
img = target.clone().squeeze()  # squeeze away the batch_size dimension
img = denorm(img).clamp_(0, 1)  # after denorm some values fall outside [0, 1], so clamp them back into that range
plt.figure()
imshow(img, title='Target Image')  # see what the trained target_pic looks like
a = transforms.ToPILImage()
a(img.cpu()).save('./final.jpg')  # save target_pic