终于写到最后一部分了,前面读完了paper和代码,现在看下如何复现这个工作:
想了解这个工作的可以看下前面的paper和介绍:原paperarxiv.org科技猛兽:GAN Compression1:原理分析zhuanlan.zhihu.com科技猛兽:GAN Compression 2:代码解读zhuanlan.zhihu.com
官方代码:https://github.com/mit-han-lab/gan-compression/blob/master/README.mdgithub.com
所需环境:Linux
Python 3
CPU or NVIDIA GPU + CUDA CuDNN
接下来正式开始!
首先把代码clone到本地:
git clone git@github.com:mit-han-lab/gan-compression.git
cdgan-compression
1.装一个PyTorch1.4,以及一些依赖库(torchvision):
细心的你可能发现代码里面有个requirement.txt,里面是这样的:
absl-py==0.9.0
blessings==1.7
certifi==2019.11.28
dominate==2.4.0
grpcio==1.16.1
Markdown==3.1.1
numpy==1.18.1
nvidia-ml-py3==7.352.0
olefile==0.46
opencv-python==4.2.0.32
Pillow==7.0.0
protobuf==3.11.3
psutil==5.7.0
scipy==1.4.1
six==1.14.0
tensorboard==2.0.0
tensorboardX==2.0
torch==1.4.0
torchvision==0.5.0
torchprofile==0.0.1
tqdm==4.42.1
Werkzeug==1.0.0
wget==3.2
不要怀疑,你没有眼花,这些都是需要先装好的~
如果你是pip选手,那就:
pip install -r requirements.txt
如果你是conda选手,请打开scripts文件夹,里面有一个可执行文件conda_deps.sh,你就:
scripts/conda_deps.sh
那这个conda_deps.sh是个什么玩意咧?
打开后发现是:
#!/usr/bin/env bash
set -ex
conda install pytorch==1.4.0 torchvision==0.5.0 -c pytorch
conda install tqdm scipy tensorboard
conda install -c conda-forge tensorboardx
pip install opencv-python dominate wget
还是这一堆库。
这里建议直接conda创建一个新的虚拟环境,专门用来跑这个实验,创建的方法是:
conda create -n your_env_name python=X.X(2.7、3.6等)
查看你都有哪些虚拟环境:
conda info -e
查看安装了哪些包:
conda list
接下来还要安装torchprofile:https://github.com/zhijian-liu/torchprofilegithub.com
pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git
这个应该是计算macs的库。
2.到这里所依赖的库都装完了,现在该准备数据集了并尝试作者的预训练模型:
2.1 CycleGAN
Download the CycleGAN dataset (e.g., horse2zebra):
bash datasets/download_cyclegan_dataset.sh horse2zebra
获取ground-truth image的统计信息,以计算FID值(CycleGAN dataset使用FID作为评价指标)。
bash datasets/download_real_stat.sh horse2zebra A
bash datasets/download_real_stat.sh horse2zebra B
在训练之前,我们可以先试试作者给的Pre-trained model,看看效果如何。
首先下载Pre-trained model(原模型和压缩之后的模型):
python scripts/download_model.py --model cyclegan --task horse2zebra --stage full
python scripts/download_model.py --model cyclegan --task horse2zebra --stage compressed
测试一下没压缩过的大模型:
bash scripts/cycle_gan/horse2zebra/test_full.sh
测试一下压缩后的模型:
bash scripts/cycle_gan/horse2zebra/test_compressed.sh
看一下这个执行文件的内容:
#!/usr/bin/env bash
python test.py --dataroot database/horse2zebra/valA \
--dataset_mode single \
--results_dir results-pretrained/cycle_gan/horse2zebra/compressed \
--config_str 16_16_32_16_32_32_16