基于pytorch框架下的partialconv图像修复代码的训练,代码链接如下:GitHub - tanimutomo/partialconv: Re-Implementation of "Image Inpainting for Irregular Holes using Partial Convolution"(非官方)
训练流程:
1.下载预训练权重
https://drive.google.com/file/d/1sooo-BLSNRUGWG_AB-lxh7xHgJ2bS29a/view#/点击链接下载预训练权重(places2数据集),放置在项目的./partialconv文件夹下 (此文件夹需自己新建)
2.下载places2官方数据集,并在default_config.yml中修改数据集文件路径
data_root: D:\User\desktop\strawberry-test # #!~/data
3.python experiment.py
报错:
把将os.mkdir()改为os.makedirs()
4.开始训练
训练以及可视化结果将保存到ckpt下子目录的models和val_vis中,其中模型保存和可视化结果保存的迭代次数设置在default_config中根据个人需求进行修改(如果用自己的数据集,则需要把batch size和学习率等参数进行调整)
5.训练自己的数据集
数据集文件夹按照以下组织形式:
mask和val_mask按需调整,本文代码中提供了随机生成mask的.py文件,即generate_mask,可以设置随机生成的图片数量。
6.开始训练:
7.注意事项
数据集图片数量如果过小,则训练结束次数将未达到规定的最大迭代次数,可使用图像增强代码进行数据增强扩充数据集。
(未完接下篇)