第六节 图片风格迁移
- 图片风格迁移
- 用GAN生成MNIST
- 用DCGAN生成更复杂的图片
matplotlib inline
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Pick the compute device once: the default CUDA GPU if PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image from disk, optionally resize it and turn it into a tensor.

    Args:
        image_path: path of the image file to open.
        transform: optional callable (e.g. a torchvision transform) applied to
            the RGB PIL image. Its result gets a leading batch dimension via
            ``unsqueeze(0)`` and is moved to ``device``.
        max_size: if given, rescale so that the longer side becomes
            ``max_size`` pixels (aspect ratio kept, sizes truncated to int).
        shape: if given, resize to exactly this size. PIL order is
            ``(width, height)``.

    Returns:
        The transformed tensor on ``device`` when ``transform`` is given.
        Otherwise the (resized) RGB PIL image, since a PIL image cannot be
        moved to a device.
    """
    # Open inside `with` so the file handle is released. convert("RGB") forces
    # the pixels to load before the file closes, and it drops any alpha/palette
    # channel so the image always has the 3 channels a 3-value Normalize expects.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    if max_size:
        scale = max_size / max(image.size)
        # int() truncates like the former np.array(...).astype(int). Clamp to
        # at least 1 px because PIL rejects zero-sized targets.
        new_size = tuple(max(1, int(dim * scale)) for dim in image.size)
        # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
        image = image.resize(new_size, Image.LANCZOS)

    if shape:
        image = image.resize(tuple(shape), Image.LANCZOS)

    if transform is None:
        return image
    return transform(image).unsqueeze(0).to(device)
# Preprocessing shared by both images: PIL image -> float tensor (ToTensor),
# then per-channel normalisation with the mean/std below (the usual ImageNet
# statistics for torchvision's pretrained models).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

content = load_image("png/content.png", transform, max_size=400)
# `content` is laid out as (1, C, H, W), but PIL's resize takes
# (width, height). Pass (W, H) so the style image gets exactly the content
# image's spatial size instead of a transposed one.
style = load_image("png/style.png", transform,
                   shape=(content.size(3), content.size(2)))
# The name was previously misspelt `stype` while later code uses `style`.
# Keep the old name as an alias so existing references still work.
stype = style
style.shape
# Converts an image tensor back into a PIL image so matplotlib can show it.
unloader = transforms.ToPILImage()
# Turn on matplotlib's interactive mode so figures are drawn without a
# blocking plt.show() call.
plt.ion()
def imshow(tensor, title=None, mean=None, std=None):
    """Display an image tensor with matplotlib.

    Args:
        tensor: image tensor. A leading dimension of size 1 (batch) is removed
            with ``squeeze(0)``; what remains is handed to ``unloader``
            (ToPILImage).
        title: optional figure title.
        mean, std: optional per-channel statistics. When both are given, the
            tensor is de-normalised as ``x * std + mean`` before display,
            undoing a ``transforms.Normalize`` step. Both default to None,
            which leaves the tensor unchanged.
    """
    # detach(): the tensor may be part of an autograd graph (e.g. the image
    # being optimised); clone() keeps the caller's tensor untouched.
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)
    if mean is not None and std is not None:
        mean_t = torch.as_tensor(mean, dtype=image.dtype).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=image.dtype).view(-1, 1, 1)
        image = image * std_t + mean_t
    if image.is_floating_point():
        # ToPILImage expects float pixels in [0, 1] and casts them to uint8.
        # Out-of-range values (e.g. normalised images) would wrap around and
        # show as noise, so clamp them first.
        image = image.clamp(0, 1)
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # give the GUI event loop a moment to draw the figure
# `style` was normalised by `transform`, so its raw values are not in [0, 1].
# Undo that normalisation (same mean/std as the transform) before display so
# the colours come out right, and clamp away rounding overshoot.
_display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
plt.figure()
imshow((style[0].detach().cpu() * _display_std + _display_mean).clamp(0, 1),
       title='Image')
class VGGNet(nn.Module):
    """Frozen pretrained VGG-19 feature extractor.

    ``forward`` runs the input through ``vgg19().features`` and returns the
    outputs of the layers whose child name (the Sequential index as a string)
    is listed in ``self.select``, in network order.
    """

    def __init__(self, select=None):
        """Build the extractor.

        Args:
            select: optional iterable of layer indices (int or str) inside
                ``vgg19().features`` whose outputs ``forward`` returns.
                Defaults to ``['0', '5', '10', '19', '28']``. These are
                presumably conv1_1, conv2_1, conv3_1, conv4_1 and conv5_1
                (the Gatys et al. style layers); verify against
                torchvision's VGG-19 layout.
        """
        super(VGGNet, self).__init__()
        if select is None:
            self.select = ['0', '5', '10', '19', '28']
        else:
            self.select = [str(s) for s in select]

        # torchvision >= 0.13 replaced `pretrained=True` with the `weights=`
        # enum (the old flag is deprecated). Fall back to it on older versions.
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is not None:
            backbone = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)
        self.vgg = backbone.features

        # The network is only used as a fixed feature extractor. Freezing it
        # stops backward() from computing and accumulating gradients for every
        # VGG weight, while gradients still flow back to the input image.
        self.vgg.requires_grad_(False)

    def forward(self, x):
        """Return the list of activations from the selected layers."""
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                # NOTE(review): torchvision's VGG uses ReLU(inplace=True), so
                # the following ReLU likely overwrites this stored tensor in
                # place (post-ReLU values). This is why we do not stop early
                # after the last selected layer; confirm before changing.
                features.append(x)
        return features
# The image being optimised starts as a copy of the content image and is the
# only tensor the optimiser updates.
target = content.clone().requires_grad_(True)
optimizer = torch.optim.Adam([target], lr=0.003, betas=(0.5, 0.999))
vgg = VGGNet().to(device).eval()
# Only `target` is optimised. Freeze VGG so backward() does not compute and
# accumulate (never used, never zeroed) gradients for every VGG weight.
vgg.requires_grad_(False)
# Initial features; the loop recomputes them every step, so do not build
# (and keep alive) an autograd graph for this extra forward pass.
with torch.no_grad():
    target_features = vgg(target)
total_step = 2000
style_weight = 100.
for step in range(total_step):
target_features = vgg(target)
content_features = vgg(content)
style_features = vgg(style)
style_loss = 0
content_loss = 0
for f1, f2, f3 in zip(target_features, content_features, style_features):
content_loss += torch.mean((f1-f2)