PyTorch Learning 10 - Image Style Transfer

Image style transfer: starting from a copy of the content image, the target image's pixels are optimized so that its VGG19 feature maps match the content image while its Gram-matrix statistics match the style image.

from torchvision import models,transforms
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_image(image_path,transform=None,max_size=None,shape=None):
    image = Image.open(image_path).convert('RGB')
    print(image.size)
    if max_size:
        scale = max_size/max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(tuple(size.astype(int)), Image.LANCZOS)
        # size.astype(int) truncates the scaled size to integers
        # LANCZOS is a high-quality antialiasing (edge-smoothing) resampling filter;
        # Image.ANTIALIAS was an alias for it and has been removed in recent Pillow versions
    if shape:
        image = image.resize(shape,Image.LANCZOS)
    print("image.size:",image.size)
    if transform:
        image = transform(image).unsqueeze(0)
        #unsqueeze(0) adds a batch dimension at position 0, giving shape (1, C, H, W)
    print("image.shape:",image.shape)
    return image.to(device)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
]) # per-channel mean and std from ImageNet

content = load_image("content.png",transform,max_size=400,shape=[400, 326])
style = load_image("style.png", transform, shape=[content.size(2), content.size(3)])

(output: printed image sizes and tensor shapes)

unloader = transforms.ToPILImage()# reconvert into PIL image
# convert tensors back to PIL images for display
plt.ion() # interactive mode; optional

def imshow(tensor,title=None):
    image = tensor.cpu().clone()  # move to CPU and copy so the original tensor is untouched
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
plt.figure()
imshow(style[0],title="Style Image")
imshow(content[0],title="Content Image")

(figure: the style and content images)
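Because the tensors are still ImageNet-normalized at this point, the preview above looks color-shifted. If a faithful preview is wanted, the normalization can be inverted first; a minimal sketch (inv_normalize is an illustrative name, its constants are simply -mean/std and 1/std):

inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225])
imshow(inv_normalize(content[0].cpu()).clamp(0, 1), title="Content (denormalized)")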

class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet,self).__init__()
        self.select = ['0','5','10','19','28'] # indices of conv1_1, conv2_1, conv3_1, conv4_1, conv5_1 in vgg19.features
        self.vgg = models.vgg19(pretrained=True).features # convolutional part of a pretrained VGG19
    def forward(self,x):
        features = []
        for name,layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features
target = content.clone().requires_grad_(True) # the target's own pixels are the parameters to be optimized
optimizer = torch.optim.Adam([target],lr=0.003,betas=[0.5,0.999])
vgg = VGGNet().to(device).eval()
target_features = vgg(target)
for i,feature in enumerate(target_features):
    print(feature.size())

(output: the sizes of the five selected feature maps)
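The indices in self.select pick out the first convolution of each VGG19 block (conv1_1 through conv5_1). A quick way to confirm this, as a separate sketch:

for name, layer in models.vgg19(pretrained=True).features._modules.items():
    if name in ['0', '5', '10', '19', '28']:
        print(name, layer)  # each selected index should print a Conv2d module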

total_step = 2000
style_weight = 100.
for step in range(total_step):
    target_features = vgg(target)
    content_features = vgg(content)
    style_features = vgg(style)
    
    style_loss = 0
    content_loss = 0
    for f1,f2,f3 in zip(target_features,content_features,style_features):
        content_loss += torch.mean((f1-f2)**2) # content loss: mean squared error between target and content features
        _,c,h,w = f1.size()
        f1 = f1.view(c,h*w) # flatten each channel's feature map into a row vector
        f3 = f3.view(c,h*w)
        
        # Gram matrix: matrix product of the flattened features with their transpose (f1.t())
        f1 = torch.mm(f1,f1.t())
        f3 = torch.mm(f3,f3.t())
        style_loss += torch.mean((f1-f3)**2)/(c*h*w) # style loss: MSE between Gram matrices, normalized by c*h*w
    loss = content_loss + style_weight * style_loss
    # total loss = content_weight * content_loss + style_weight * style_loss; the content weight here is implicitly 1
    
    # 更新target
    optimizer.zero_grad()
    loss.backward()
    # only target has requires_grad=True, so only its pixels receive gradient updates
    optimizer.step()
    
    if step%10==0:
        print("Step[{}/{}],Content Loss:{:.4f},Style Loss:{:.4f}"
             .format(step,total_step,content_loss.item(),style_loss.item()))

(output: content and style loss printed every 10 steps)
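The style term compares Gram matrices G = F·Fᵀ, where F is a feature map flattened to shape (c, h*w); dividing by c*h*w keeps the loss scale comparable across layers. The inline computation above can also be factored into a small helper, shown here purely as an equivalent sketch:

def gram_matrix(feature):
    # feature: a (1, c, h, w) activation from one VGG layer
    _, c, h, w = feature.size()
    f = feature.view(c, h * w)   # flatten the spatial dimensions
    return torch.mm(f, f.t())    # (c, c) matrix of channel-to-channel correlations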

denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44)) # inverse of the ImageNet normalization
img = target.clone().squeeze()
img = denorm(img).clamp_(0,1) # clamp values below 0 to 0 and above 1 to 1
plt.figure()
imshow(img,title="Target Image")
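The hard-coded denorm constants follow from the mean/std used earlier: Normalize computes (x - mean)/std, so applying Normalize again with mean -mean/std and std 1/std undoes it. A quick check of the numbers, as a sketch:

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
print(-mean / std)  # approx. [-2.118, -2.036, -1.804]
print(1 / std)      # approx. [ 4.367,  4.464,  4.444]

If the result should be saved to disk rather than shown with matplotlib, torchvision.utils.save_image(img, "target.png") can be called on the denormalized tensor.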

(figure: the final stylized target image)
