Pixel2Pixel:人像卡通化
论文:Image-to-Image Translation with Conditional Adversarial Networks
论文链接:https://arxiv.org/abs/1611.07004
Pixel2Pixel:人像卡通化
准备工作:引入依赖 & 数据准备
import paddle
import paddle.nn as nn
from paddle.io import Dataset, DataLoader
import os
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Sized
数据准备:
- 真人数据来自seeprettyface。
- 数据预处理(详情见photo2cartoon项目)。
- 使用photo2cartoon项目生成真人数据对应的卡通数据。
# Unzip the paired photo/cartoon dataset into data/ (-o: overwrite existing files without prompting, -q: quiet).
# The cells below read data/cartoon_A2B/train and data/cartoon_A2B/test, so the archive is expected to contain cartoon_A2B/.
!unzip -oq data/data79149/cartoon_A2B.zip -d data/
数据可视化
# Dataset statistics: number of paired images in each split.
train_dir = os.path.join('data', 'cartoon_A2B', 'train')
test_dir = os.path.join('data', 'cartoon_A2B', 'test')

train_names = os.listdir(train_dir)
print(f'训练集数据量: {len(train_names)}')
test_names = os.listdir(test_dir)
print(f'测试集数据量: {len(test_names)}')

# Visualize 3 distinct random training samples stacked vertically. Each file
# holds the real photo (left half) and its cartoon (right half) side by side.
imgs = []
for img_name in np.random.choice(train_names, 3, replace=False):
    img_path = os.path.join(train_dir, img_name)
    img = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when a file cannot be decoded;
    # fail here with the offending path rather than later inside np.vstack.
    if img is None:
        raise OSError(f'Failed to read image: {img_path}')
    imgs.append(img)
    print(img.shape)

# OpenCV loads images as BGR; reverse the channel axis to RGB for matplotlib.
img_show = np.vstack(imgs)[:, :, ::-1]
plt.figure(figsize=(10, 10))
plt.imshow(img_show)
plt.show()
训练集数据量: 1361
测试集数据量: 100
(256, 512, 3)
(256, 512, 3)
(256, 512, 3)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return list(data) if isinstance(data, collections.MappingView) else data
class PairedData(Dataset):
    """Paired photo -> cartoon dataset.

    Each file under ``<data_root>/<phase>`` is a single image containing the
    real photo in its left half and the matching cartoon in its right half.
    ``__getitem__`` returns ``(img_A, img_B)``: float32 CHW arrays scaled to
    [-1, 1], channels kept in OpenCV's BGR order.
    """

    def __init__(self, phase, data_root='data/cartoon_A2B'):
        super(PairedData, self).__init__()
        self.img_path_list = self.load_A2B_data(phase, data_root)  # image file paths
        self.num_samples = len(self.img_path_list)  # dataset size

    def __getitem__(self, idx):
        img_path = self.img_path_list[idx]
        img_A2B = cv2.imread(img_path)  # HWC, BGR
        if img_A2B is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f'Failed to read image: {img_path}')
        img_A2B = img_A2B.astype('float32') / 127.5 - 1.  # [0, 255] -> [-1, 1]
        img_A2B = img_A2B.transpose(2, 0, 1)  # HWC -> CHW
        # Split along the width axis: left half is the photo, right half the
        # cartoon (256 | 256 for the 512-wide images in this dataset).
        half_w = img_A2B.shape[-1] // 2
        img_A = img_A2B[..., :half_w]  # real photo
        img_B = img_A2B[..., half_w:]  # cartoon
        return img_A, img_B

    def __len__(self):
        return self.num_samples

    @staticmethod
    def load_A2B_data(phase, data_root='data/cartoon_A2B'):
        """Return the sorted list of image paths for ``phase``.

        Raises:
            ValueError: if ``phase`` is not 'train' or 'test'.
        """
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if phase not in ('train', 'test'):
            raise ValueError("phase should be set within ['train', 'test']")
        data_path = os.path.join(data_root, phase)
        # os.listdir order is arbitrary; sort for a reproducible sample order.
        # Skip hidden entries (e.g. .DS_Store), which cv2.imread cannot decode.
        return [os.path.join(data_path, x)
                for x in sorted(os.listdir(data_path))
                if not x.startswith('.')]
# Build the training and test splits of the paired photo/cartoon dataset.
paired_dataset_train, paired_dataset_test = (PairedData(split) for split in ('train', 'test'))
第一步:搭建生成器
请大家补全空白处的代码，`#` 后面的内容为提示。
原始代码打印的模型尺寸输出如下：
class UnetGenerator(nn.Layer):
def __init__(self, input_nc=3, output_nc=3, ngf=64):
super(UnetGenerator, self).__init__()
self.down1 = nn.Conv2D(input_nc, ngf, kernel_size=4, stride=2, padding=1)
self.down2 = Downsample(ngf, ngf*2)
self.down3 = Downsample(ngf*2, ngf*4)
self.down4 = Downsample(ngf*4, ngf*8)
self.down5 = Downsample(ngf*8, ngf*8)
self.down6 = Downsample(ngf*8, ngf*8