Python Style Transfer: Image Style Transfer (Neural Style)

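Image style transfer uses a pretrained model whose weights are not trained any further. Instead, the pixel values of the image are set as the trainable parameters, and the image itself is what gets optimized.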

The goal is to combine the content of one image with the style of another image to generate a new image.
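To make the idea of optimizing pixels instead of weights concrete, here is a minimal sketch. It assumes a small, randomly initialized convolution as a hypothetical stand-in for the pretrained network (so nothing has to be downloaded), and the names `model`, `reference`, and `image` are illustrative only. The weights are frozen and the optimizer is handed just the image tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained network: one frozen conv layer.
model = nn.Conv2d(3, 4, kernel_size=3, padding=1).eval()
for p in model.parameters():
    p.requires_grad_(False)

# Features of a fixed reference image; no gradient tracking is needed here.
reference = torch.rand(1, 3, 8, 8)
with torch.no_grad():
    reference_features = model(reference)

# The image is the only trainable tensor, so it is all the optimizer sees.
image = torch.rand(1, 3, 8, 8).requires_grad_(True)
optimizer = torch.optim.Adam([image], lr=0.05)

losses = []
for step in range(200):
    loss = torch.mean((model(image) - reference_features) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])  # the loss should shrink as the pixels are updated
```

Only `image` changes during the loop; the model parameters never receive a gradient.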

[Image: up-a3c765e68cac23551f4bae3f18bcb26ded2.png]

Paper: https://arxiv.org/pdf/1508.06576.pdf

Loss function:

[Image: up-437e7226036a667d989ed59aa208f7d2167.png]

[Image: up-09a61e036c45b1123229eec095beca6c772.png]
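For reference, the losses defined in the cited paper can be written as follows. Here $F^l$ and $P^l$ are the layer-$l$ feature maps of the generated image $\vec{x}$ and the content image $\vec{p}$, $G^l$ and $A^l$ are the Gram matrices of the generated image and the style image $\vec{a}$, $N_l$ is the number of feature maps in layer $l$, and $M_l$ is the size (height times width) of each map. Note that the code in this post appears to use a simplified variant: mean squared errors summed over all five selected layers, with a fixed style weight of 100.

```latex
\begin{align}
\mathcal{L}_{content}(\vec{p}, \vec{x}, l) &= \frac{1}{2} \sum_{i,j} \left( F^l_{ij} - P^l_{ij} \right)^2 \\
G^l_{ij} &= \sum_{k} F^l_{ik} F^l_{jk} \\
E_l &= \frac{1}{4 N_l^2 M_l^2} \sum_{i,j} \left( G^l_{ij} - A^l_{ij} \right)^2 \\
\mathcal{L}_{style}(\vec{a}, \vec{x}) &= \sum_{l} w_l E_l \\
\mathcal{L}_{total}(\vec{p}, \vec{a}, \vec{x}) &= \alpha \, \mathcal{L}_{content}(\vec{p}, \vec{x}) + \beta \, \mathcal{L}_{style}(\vec{a}, \vec{x})
\end{align}
```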

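The code below feeds the image x into a pretrained VGG19 model and takes the outputs of layers 0, 5, 10, 19, and 28 of the model's features module (the first convolution in each of VGG19's five blocks). From these outputs it computes the sum of the content loss and the style loss, then backpropagates that loss to update the pixel values of the image.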
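The style loss compares Gram matrices of the feature maps rather than the feature maps themselves. The sketch below uses NumPy with random arrays as stand-in features (the helper name `gram_matrix` is an assumption, not part of the original code). It illustrates why this statistic describes style rather than layout: the Gram matrix only records how strongly channels activate together, so shuffling the spatial positions leaves it unchanged.

```python
import numpy as np

def gram_matrix(features):
    """Return the (C, C) Gram matrix of a (C, H, W) feature map."""
    c, h, w = features.shape
    flat = features.reshape(c, h * w)  # one row per channel
    return flat @ flat.T               # inner products between channels

rng = np.random.default_rng(0)
features = rng.standard_normal((4, 5, 6))  # stand-in for one layer's output

# Shuffle the spatial positions in the same way for every channel.
perm = rng.permutation(5 * 6)
shuffled = features.reshape(4, -1)[:, perm].reshape(4, 5, 6)

gram = gram_matrix(features)
gram_shuffled = gram_matrix(shuffled)

print(gram.shape)                        # (4, 4)
print(np.allclose(gram, gram_shuffled))  # True
```

The content loss, by contrast, compares feature maps position by position, which is what preserves the layout of the content image.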

"""

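Image style transfer with a pretrained VGG network. The VGG model is not
trained; instead the image itself is optimized until it has the style we want.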

"""

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

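def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image, optionally resize it, and return it as a batched tensor."""
    # Force three channels so grayscale or RGBA files match the VGG input.
    image = Image.open(image_path).convert("RGB")

    if max_size:
        # Scale so that the longer side equals max_size, keeping the aspect ratio.
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        image = image.resize(tuple(int(v) for v in size), Image.LANCZOS)

    if shape:
        # PIL expects the size as (width, height).
        image = image.resize(tuple(shape), Image.LANCZOS)

    if transform:
        image = transform(image).unsqueeze(0)

    return image.to(device)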

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # Per-channel mean and std of the ImageNet training set that VGG was trained on.
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

content = load_image("data/content.jpg", transform, max_size=400)
# PIL's resize takes (width, height), while the tensor layout is (N, C, H, W).
style = load_image("data/style.jpg", transform, shape=[content.size(3), content.size(2)])

print(content.shape)  # e.g. torch.Size([1, 3, 400, 300])
print(style.shape)    # same shape as content

def imshow(tensor, title=None):
    """Display an image tensor, dropping a leading batch dimension of size 1."""
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)
    # ToPILImage is a class, so instantiate it first and then call the instance.
    image = torchvision.transforms.ToPILImage()(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(2)  # pause a bit so that plots are updated

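class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet, self).__init__()
        # Indices into vgg19.features: the first convolution of each of the five blocks.
        self.select = ['0', '5', '10', '19', '28']
        # Newer torchvision releases deprecate pretrained=True in favor of the weights= argument.
        self.vgg = torchvision.models.vgg19(pretrained=True).features

    def forward(self, x):
        """Return the activations of the selected layers as a list."""
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features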

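# The model is not trained, so put it in eval() mode.
vgg = VGGNet().to(device).eval()
# Freeze the weights so that backward() only computes gradients for the image.
for param in vgg.parameters():
    param.requires_grad_(False)

for feat in vgg(content):
    print(feat.shape)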

# Train the image itself, starting from a copy of the content image.
target = content.clone().requires_grad_(True)
print("target shape: ", target.shape)

optimizer = torch.optim.Adam([target], lr=0.003, betas=(0.5, 0.999))
num_steps = 2000

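# The content and style images never change, so compute their features once,
# without tracking gradients.
with torch.no_grad():
    content_features = vgg(content)
    style_features = vgg(style)

for step in range(num_steps):
    target_features = vgg(target)

    content_loss = style_loss = 0.
    for f1, f2, f3 in zip(target_features, content_features, style_features):
        # Content loss: mean squared difference between the feature maps.
        content_loss += torch.mean((f1 - f2) ** 2)

        # Flatten each channel into a row (the batch size is assumed to be 1).
        _, c, h, w = f1.size()
        f1 = f1.view(c, h * w)
        f3 = f3.view(c, h * w)

        # Gram matrices: channel-by-channel correlations, shape (c, c).
        f1 = torch.mm(f1, f1.t())
        f3 = torch.mm(f3, f3.t())

        # Style loss: mean squared difference between the Gram matrices.
        style_loss += torch.mean((f1 - f3) ** 2) / (c * h * w)

    loss = content_loss + style_loss * 100.

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print("Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}"
              .format(step, num_steps, content_loss.item(), style_loss.item()))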

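# Undo the input normalization: the inverse uses mean' = -mean/std and std' = 1/std.
denorm = torchvision.transforms.Normalize([-2.12, -2.04, -1.80], [4.37, 4.46, 4.44])
img = target.detach().clone().squeeze()
img = denorm(img).clamp_(0, 1)
imshow(img, title="Target Image")
plt.pause(100)  # keep the window open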
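The constants passed to `denorm` follow directly from the mean and std used in the input transform. The plain-Python sketch below derives them and checks that normalizing and then denormalizing a pixel returns the original values. The helper `normalize` is hypothetical and only mimics the per-channel arithmetic of `Normalize`.

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Normalize computes y = (x - mean) / std, so the inverse is
# x = (y - (-mean / std)) / (1 / std), which is again a Normalize call.
inv_mean = [-m / s for m, s in zip(mean, std)]
inv_std = [1.0 / s for s in std]

print([round(v, 2) for v in inv_mean])  # [-2.12, -2.04, -1.8]
print([round(v, 2) for v in inv_std])   # [4.37, 4.46, 4.44]

def normalize(x, mean, std):
    """Apply (x - mean) / std channel by channel (hypothetical helper)."""
    return [(v - m) / s for v, m, s in zip(x, mean, std)]

pixel = [0.2, 0.5, 0.9]
restored = normalize(normalize(pixel, mean, std), inv_mean, inv_std)
print(restored)  # approximately [0.2, 0.5, 0.9]
```

Because the script rounds these constants to two decimals, the recovered pixel values are very slightly off, which is harmless after clamping to [0, 1].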
