环境要求
pytorch 0.1
0.2版本会有问题,解决方法后文会介绍。
如何查看版本?在终端输入:
python
import torch
torch.version
数据集制作
本文代码使用的数据加载方法是datasets.ImageFolder,它要求数据集不同类别的图片放在不同文件夹下,文件格式如下:
将自己的数据集做成图中的形式,即可。
运行过程
train
python finetune.py –train –train_path path_to_your_dataset/train/ –test_path path_to_your_dataset/test/
这里会去掉VGG的后三层fc,而根据自己的num_class训练新的fc层,总共迭代20个epoch。
prune
python finetune.py –prune –train_path path_to_your_dataset/train/ –test_path path_to_your_dataset/test/
这里会开始对filter进行剪枝,每次剪512个filter,每次剪完都会迭代10个epoch以恢复模型的能力;然后继续下一次剪枝,直至将VGG模型的2/3的filter剪掉;最后进行15个epoch以得到最终的剪枝模型。
ps:这里的512是自己设定的,论文中是1,但每次只剪1个filter太慢,所以设定为512,加快剪枝过程。