使用stylegan2训练自己的数据集
官方开源链接
链接: link.
对数据集进行处理
数据集resize图片尺寸
// resize image
from glob import glob
from PIL import Image
import os
from tqdm import tqdm
from tqdm._tqdm import trange
img_path = glob("./resize/*.png")
path_save = "./resize/"
a = range(0, len(img_path))
i = 0
for file in tqdm(img_path):
name = os.path.join(path_save, "%d.png" % a[i])
im = Image.open(file)
im.thumbnail((1024, 1024))
print(im.format, im.size, im.mode)
im.save(name, 'png')
i += 1
生成数据集对应的tfrecords格式
// 第一个目录参数为tfrecords格式存放的目录,第二个目录参数为resize后images图片路径
python dataset_tool.py create_from_images ~/datasets/my-custom-dataset ~/my-custom-images
//可视化数据集
python dataset_tool.py display ~/datasets/my-custom-dataset
训练
// config文件分为f和e,对应用不同的显存大小训练
python run_training.py --num-gpus=1 --data-dir=datasets --config=config-e --dataset=custome_dataset1 --mirror-augment=true
测试
// seeds为生成的照片索引,可以取多个值
# Generate 1000 random images without truncation
python run_generator.py generate-images --seeds=0-999 --truncation-psi=1.0 --network=results/00006-stylegan2-ffhq-8gpu-config-f/networks-final.pkl
#example
python run_generator.py generate-images --seeds=9,66,286 --truncation-psi=1.0 --network=results/00007-stylegan2-custome_dataset-1gpu-config-e/network-snapshot-001323.pkl