总共数据是202599张图片,在这里的分了训练集162770(比例0.8),验证集19867(比例0.098),测试集19962(比例0.096)
注意文件夹的创建
##base_path = './data'
def add_splits(base_path):
data_path = os.path.join(base_path, 'CelebA')
images_path = os.path.join(data_path, 'images')
train_dir = os.path.join(data_path, 'splits', 'train')
valid_dir = os.path.join(data_path, 'splits', 'valid')
test_dir = os.path.join(data_path, 'splits', 'test')
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(valid_dir):
os.makedirs(valid_dir)
if not os.path.exists(test_dir):
os.makedirs(test_dir)
这里不用给的download.py来下载CelebA数据集,我自己已经准备好了,把CelebA数据集里的图片全放在images文件夹里。
if __name__ == '__main__':
base_path = './data'
prepare_data_dir()
##download_celeb_a(base_path)
add_splits(base_path)
把download_celeb_a(base_path)这行注释掉,再以管理员权限运行download.py
def check_link(in_dir, basename, out_dir):
in_file = os.path.join(in_dir, basename)
if os.path.exists(in_file):
link_file = os.path.join(out_dir, basename)
rel_link = os.path.relpath(in_file, out_dir)
os.symlink(rel_link, link_file)
建立软链接,然后运行main函数,开始训练。
中断后如何恢复训练:
添加<model’s name>_xxxx_xxxxxxx到–load_path,
–load_path=CelebA_0507_210503