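# Neural style transfer: optimise an image so that it keeps the content of one picture while taking on the style of another.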
from __future__ import division
from torchvision import models, transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Preprocessing applied to every loaded image: convert the PIL image to a
# float tensor, then normalise each channel with a fixed mean and std.
# NOTE(review): these values look like the standard ImageNet channel
# statistics used with torchvision's pretrained models -- confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image from disk, optionally resize it, and apply ``transform``.

    Args:
        image_path: Path of the image file to open.
        transform: Optional callable applied to the (RGB) PIL image. Its
            result is expected to be a tensor: a leading batch dimension is
            added and the tensor is moved to the module-level ``device``.
        max_size: If truthy, rescale the image, keeping its aspect ratio, so
            that its longer side becomes ``max_size`` pixels.
        shape: If truthy, resize the image to exactly this
            ``(width, height)`` pair -- the order PIL's ``resize`` expects.
            Applied after ``max_size``.

    Returns:
        The transformed tensor with a leading batch dimension, on ``device``;
        or the resized PIL image itself when ``transform`` is ``None``.
    """
    # The context manager closes the file handle. convert() decodes the pixel
    # data into a new image, so the result stays usable afterwards, and it
    # guarantees three channels for palette/grayscale/RGBA inputs.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    if max_size:
        scale = max_size / max(image.size)
        # PIL wants a plain tuple of ints; never let a side collapse to 0.
        scaled_size = tuple(max(1, int(side * scale)) for side in image.size)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        image = image.resize(scaled_size, Image.LANCZOS)

    if shape:
        image = image.resize(tuple(int(side) for side in shape), Image.LANCZOS)

    if transform is None:
        # A PIL image has no .to(); return it as-is instead of raising.
        return image

    return transform(image).unsqueeze(0).to(device)
# Content image, rescaled so that its longer side is 400 px. After ToTensor
# and the added batch dimension the tensor layout is (1, C, H, W).
content = load_image("image/content.jpg", transform, max_size=400)

# Style image, resized to the content image's dimensions. PIL's resize takes
# (width, height), i.e. tensor dims 3 and 2 -- passing (H, W) would transpose
# and distort any non-square image.
style = load_image(
    "image/style.jpg",
    transform,
    shape=[content.size(3), content.size(2)],
)
class VGGNet(nn.Module):
    """Frozen VGG-19 feature extractor returning selected layer activations."""

    def __init__(self):
        """Load the pretrained VGG-19 convolutional stack and freeze it."""
        super(VGGNet, self).__init__()
        # Names of the sub-layers of ``vgg19().features`` whose outputs are
        # collected by forward().
        # NOTE(review): in torchvision's VGG-19 these indices should be the
        # first convolution of each of the five blocks -- confirm.
        self.select = ['0', '5', '10', '19', '28']

        # ``pretrained=True`` is deprecated in newer torchvision releases;
        # use the weights enum when it exists and fall back otherwise.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            backbone = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)
        self.vgg = backbone.features

        # The network is only a fixed feature extractor: freezing its weights
        # stops autograd from computing parameter gradients nobody uses,
        # while gradients still flow back to the input.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through every layer and return the selected outputs.

        Args:
            x: Input batch fed to the first layer of ``self.vgg``.

        Returns:
            A list with the output of each layer whose name is in
            ``self.select``, in network order.
        """
        features = []
        # NOTE(review): all layers are run on purpose, with no early exit.
        # torchvision's VGG appears to use in-place ReLUs, so the ReLU after
        # a selected conv also rewrites the collected tensor -- confirm
        # before stopping the loop early.
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features
# Build the feature extractor, move it to the chosen device and switch it to
# evaluation mode.
vgg = VGGNet().to(device).eval()

# Sanity check: print the shape of every feature map collected for the
# content image. ``features`` stays alive at module level, so the pass runs
# under no_grad to avoid keeping an autograd graph alive with it.
with torch.no_grad():
    features = vgg(content)
for feat in features:
    print(feat.shape)
# Number of optimisation steps to run.
num_steps = 2000

# The image being optimised starts as a copy of the content image and is the
# only tensor handed to the optimiser.
target = content.clone()
target.requires_grad_(True)

optimizer = torch.optim.Adam(
    params=[target],
    lr=0.003,
    betas=[0.5, 0.999],
)
# The content and style images never change, so their VGG activations and the
# style Gram matrices are identical at every step. Compute them once, without
# recording an autograd graph, instead of inside the loop.
with torch.no_grad():
    content_features = vgg(content)
    style_features = vgg(style)
    style_grams = []
    for style_feat in style_features:
        # view(c, h*w) drops the leading batch dimension, so this relies on
        # the batch holding a single image.
        _, c, h, w = style_feat.size()
        style_flat = style_feat.view(c, h * w)
        style_grams.append(torch.mm(style_flat, style_flat.t()))

for step in range(num_steps):
    # Only the target image changes between steps.
    target_features = vgg(target)

    content_loss = style_loss = 0.
    for f1, f2, f3 in zip(target_features, content_features, style_grams):
        # Content loss: mean squared difference of the raw activations.
        content_loss += torch.mean((f1 - f2) ** 2)

        # Style loss: mean squared difference of the channel-by-channel Gram
        # matrices, scaled by the number of activations in the layer.
        _, c, h, w = f1.size()
        f1 = f1.view(c, h * w)
        f1 = torch.mm(f1, f1.t())
        style_loss += torch.mean((f1 - f3) ** 2) / (c * h * w)

    # The style term is weighted 100x relative to the content term.
    loss = content_loss + style_loss * 100

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        # .item() turns the 0-dim loss tensors into plain Python floats.
        print("Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}, Total Loss: {:.4f}"
              .format(step, num_steps, content_loss.item(), style_loss.item(), loss.item()))