基本到了最后一课,主要内容是复现论文Image Style Transfer Using Convolutional Neural Networks的内容,这篇论文主要是希望将 图片的 content 和 另外一幅图片的style 合并在一起,生成如下的图片:
有兴趣的小伙伴可以去精读一下这篇论文,其实思路很简单:使用一个已经训练好的卷积神经网络(可以直接提取特征),将含有内容的图片 content image 输入到网络中,得到某些层的输出,作为 content feature;将另外一幅含有风格的图片 style image 输入到网络中,提取某些层的输出,获得其 style feature。然后我们有一个 target image,最开始初始化为白噪声(在程序里初始化为 content image),也输入到网络中,得到对应层的输出;我们希望这些输出在内容方面和 content feature 接近,在风格方面和 style feature 接近,然后不断对 target image 进行梯度下降求解。所以整个过程中,梯度下降更新的一直是这个 target image,网络结构中的参数并没有改变。下面来进行具体的代码讲解。代码都是重新手敲的,可能会有一些小问题,有问题看到会解决哈。
首先是加载我们的模型,这里我们按照论文,使用vgg19网络中的feature部分:
# Load the pretrained VGG19 and keep only its convolutional `features` part,
# following the paper "Image Style Transfer Using Convolutional Neural Networks".
import torch
# BUG FIX: `import torch.nn as nn, models` tried to import a nonexistent
# top-level module `models`; it must come from torchvision.
import torch.nn as nn
from torchvision import models
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

model = models.vgg19(pretrained=True)
# If the download is slow, fetch the weight file manually and load it instead:
# model = models.vgg19()
# model.load_state_dict(torch.load('vgg19-dcbb9e9d.pth'))
# model.eval()
feature_extraction_net = model.features
# Freeze every network parameter: in style transfer only the target image is
# optimized; the network weights never change.
for param in feature_extraction_net.parameters():
    param.requires_grad = False
# BUG FIX: `is_avaliable` -> `is_available`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extraction_net.to(device)
print(feature_extraction_net)
可以看到网络结构如下:
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace=True)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace=True)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace=True)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace=True)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace=True)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace=True)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
然后我们输入我们想要的content image 和style image
#从tensor到numpy的函数:
# Convert a normalized image tensor of shape (1, C, H, W) back into a
# displayable numpy array of shape (H, W, C) with values in [0, 1].
def im_Tensor2Np(tensor):
    # BUG FIX: `.cpu` is a method and must be called; detach first so tensors
    # with requires_grad=True (the target image) can be converted.
    image = tensor.detach().cpu().numpy()
    # BUG FIX: `sequeeze` typo, and squeeze() is not in-place — the result
    # must be assigned back. Drops the batch dimension: (C, H, W).
    image = image.squeeze()
    image = image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    # Undo Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): x * std + mean.
    image = image * np.array([0.5, 0.5, 0.5]) + np.array([0.5, 0.5, 0.5])
    image = image.clip(0, 1)
    return image
def load_image(image_path, max_size=400, shape=None):
    """Load an image, resize it so its longer side is at most ``max_size``
    (or exactly ``shape`` if given), and return a normalized (1, C, H, W)
    tensor ready for the network."""
    image = Image.open(image_path).convert('RGB')
    # PIL's image.size is a (width, height) tuple; cap the longer side.
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)
    # An explicit shape (e.g. to match the content image) overrides max_size.
    if shape is not None:
        size = shape
    # BUG FIX: the original referenced an undefined name `transform`;
    # the imported module is `transforms`.
    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # unsqueeze(0) adds the batch dimension the network expects.
    image = in_transform(image).unsqueeze(0)
    return image
content_img = load_image('City.jpg').to(device)
# Force the style image to the content image's spatial size so the feature
# maps line up layer by layer.
# NOTE(review): this loads 'City.jpg' for the style image too — for an actual
# style transfer it should be a different (style) picture; confirm filename.
style_img = load_image('City.jpg', shape=content_img.shape[-2:]).to(device)
# Display both images side by side.
# BUG FIX: matplotlib's function is `plt.subplot`, not `plt.sub_plot`.
plt.subplot(1, 2, 1)
plt.imshow(im_Tensor2Np(content_img))
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(im_Tensor2Np(style_img))
plt.axis('off')
plt.show()
显示的图片如下图:
然后我们从多个网络层中输出特征图。需要注意的一点是,style 和 content 的计算方式不同:style 需要先计算一个 Gram 矩阵,方法非常简单。
# All layers below are style layers except conv4_2, which is the content
# layer. Dict keys are the positions of these layers inside the `features`
# Sequential printed above.
# BUG FIX: the original indices were wrong — in VGG19's features Sequential,
# conv3_1 is index 10 (first conv after the second MaxPool), conv4_1 is 19,
# conv4_2 is 21 and conv5_1 is 28.
feature_layer = {'0': 'conv1_1',
                 '5': 'conv2_1',
                 '10': 'conv3_1',
                 '19': 'conv4_1',
                 '21': 'conv4_2',  # content layer
                 '28': 'conv5_1'}
# Per-layer weights for the style loss (earlier layers weigh more).
style_weight = {'conv1_1': 1.0,
                'conv2_1': 0.75,
                'conv3_1': 0.2,
                'conv4_1': 0.2,
                'conv5_1': 0.2}
#下面提取每一层的输出:
def get_features(model, layers, image):
    """Feed `image` through `model` one layer at a time, collecting the
    outputs of the layers whose Sequential index (as a string) is a key
    of `layers`, under the human-readable name `layers` maps it to.

    Note: a Sequential can only be indexed by position — even when modules
    were registered with names via an OrderedDict, name indexing fails.
    """
    collected = {}
    out = image
    for idx, module in enumerate(model):
        out = module(out)
        key = str(idx)
        if key in layers:
            collected[layers[key]] = out
    return collected
def gram_matrix(tensor):
    """Return the Gram matrix of a (1, C, H, W) feature map: the C x C
    matrix of inner products between the flattened channel maps, which
    captures the style statistics of the layer."""
    _, channels, dim_a, dim_b = tensor.size()
    flat = tensor.clone().view(channels, dim_a * dim_b)
    return torch.mm(flat, flat.t())
# Precompute the content features and the style Gram matrices once; they
# stay constant throughout the optimization.
content_features = get_features(feature_extraction_net, feature_layer, content_img)
style_features = get_features(feature_extraction_net, feature_layer, style_img)
# BUG FIX: the Gram matrices must be built from the style *features* — the
# original indexed `feature_layer` (whose keys are layer-index strings),
# which raises a KeyError. Named `style_grams` to match the training loop's
# lookup below.
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_weight}
得到了 content feature 和style_gram
我们就可以根据这些特征,对target image进行更新啦
# The target image starts as a copy of the content image (the paper uses
# white noise; starting from the content converges faster) and is the ONLY
# thing being optimized — the network weights stay frozen.
target_img = content_img.clone().to(device)
target_img.requires_grad = True  # enable gradients on the target image
alpha = 1.0   # content-loss weight
beta = 1e6    # style-loss weight
epoches = 20000
show_every = 3000
optimizer = torch.optim.Adam([target_img], lr=0.003)
for i in range(epoches):
    target_features = get_features(feature_extraction_net, feature_layer, target_img)
    # Content loss: MSE between target and content features at conv4_2.
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
    style_loss = 0
    for layer in style_weight:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        style_gram = style_grams[layer]
        # BUG FIX: the per-layer style loss is the weighted MSE between the
        # target and style Gram matrices — the original merely scaled the
        # style Gram and never compared it against the target Gram, so the
        # style term produced no gradient on the target image.
        layer_style_loss = style_weight[layer] * torch.mean((target_gram - style_gram) ** 2)
        # Normalize by the feature-map size so layers contribute comparably.
        _, d, w, h = target_feature.size()
        style_loss += layer_style_loss / (d * w * h)
    total_loss = alpha * content_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    if i % show_every == 0:
        # BUG FIX: the original `print('iteration:'i, ...)` was missing a
        # comma and is a syntax error.
        print('iteration:', i, 'loss:', total_loss.item())
        plt.close()
        plt.imshow(im_Tensor2Np(target_img))
        plt.show()
可以看到target_img不断地变化