前言
复现stargan的论文需要的材料。celeba数据库、128和256分辨下网络权重。
笔者已将它们全部下载并传送到某盘。
链接:https://pan.baidu.com/s/1mtMlxiaucyRwMWn9IXopRg 提取码:55fu
stargan论文:https://arxiv.org/abs/1711.09020
官方代码:https://github.com/yunjey/stargan
celeba-128x128-5attrs.zip
celeba-256x256-5attrs.zip
celeba.zip
图 文件列表
安装依赖环境
该项目依赖
- Python 3.5+
- PyTorch 0.4.0+
- TensorFlow 1.3+ (optional for tensorboard)
conda虚拟环境脚本
conda create -p env_torch python=3.6 -y
source activate env_torch/
conda install pytorch=1.1.0 torchvision=0.3.0 cudatoolkit=9.0 cudnn=7.6.5 tensorflow-gpu=1.14.0 -y
可能问题
无法使用cuda。请检查驱动版本,cuda ,cudnn,tensorflow,pytorch之间的版本对应关系。刚开始可能只能用cpu。即,
>>> import torch
>>> torch.cuda.is_available()
False
>>> torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device(type='cpu')
检查环境
(env_torch) user$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2016 NVIDIA Corporation
Built on Tue_Jan_10_13:22:03_CST_2017
Cuda compilation tools, release 8.0, V8.0.61
(env_torch) user$ nvidia-smi -l
图 驱动版本号
版本对应关系
图 cuda与驱动
图 tensorflow、cuda、cudnn对应
图 pytorch 版本对应
cuda-toolkit link
tensorflow_intall
所以我选择了以下的版本安装,即,
conda install pytorch=1.1.0 torchvision=0.3.0 cudatoolkit=9.0 cudnn=7.6.5 tensorflow-gpu=1.14.0 -y
>>> import torch
>>> torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device(type='cuda')
成功在conda下安装 cuda、tensorflow、pytorch。
测试
python main.py --mode test --dataset CelebA --image_size 128 --c_dim 5 \
--sample_dir stargan_celeba/samples --log_dir stargan_celeba/logs \
--model_save_dir stargan_celeba/models --result_dir stargan_celeba/results \
--selected_attrs Black_Hair Blond_Hair Brown_Hair Male Young
图 stargan test
结束。