数据预处理
将cifar10数据处理成两个cifar5,其中label为{0,1,2,3,4}的五类数据放置成一组,命名为first_cifar_5,label为{5,6,7,8,9}的五类数据放置成一组,命名为last_cifar_5.
数据分享在google drive/shared drive/BN-Quant/prune里。
模型训练
模型有fc,denseDNN,resnet等,本次实验以fc为模型进行。
模型训练状态分为freeze和non-freeze两个状态。
在non-freeze状态,模型首先会保存下初始状态,然后进行ITREATION=10次迭代。每次迭代,模型会训练end_iter = 20个epoch,在这100个epoch里面不会修改learning_rate。根据paper,learing_rate过大时,模型收敛的不好,所以源码中固定以learing_rate=1e0-3进行。 每训练完20个eopch,模型会进行一次剪枝,剪枝比例为当前alive 参数的10%。然后把剪枝的位置保存下来。
在freeze状态,模型会首先载入之前保存的初始状态。随后,模型加载剪枝比例最高的剪枝位置数据。最后,模型会进行end_iter = 20个epoch的训练。
使用方式
non-freeze 状态
!python /content/Lottery-Ticket-Hypothesis-in-Pytorch/main.py --end_iter 20 --dataset first_cifar_5 --arch_type fc1 --prune_iterations 10
freeze状态
!python /content/Lottery-Ticket-Hypothesis-in-Pytorch/main.py --end_iter 20 --dataset last_cifar_5 --arch_type fc1 --prune_iterations 10 --mask_path '/content/dumps/lt/fc1/first_cifar_5/lt_mask_81.0.pkl' --initial_weight_path '/content/saves/fc1/first_cifar_5/initial_state_dict_lt.pth.tar' --freeze True
参数说明
参数 | 说明 |
---|---|
freeze | 表示是freeze模式还是non-freeze模式 |
end_iter | 每次训练的epoch数量 |
dataset | 数据集。在non-freeze状态下,需要使用first_cifar_5,在freeze状态下,需要使用last_cifar_5。 |
arch_type | 选用的模型。本次实验选用了fc1.只有两个全连接层。 |
prune_iterations | 剪枝的次数。 |
mask_path | mask存储了剪枝的位置,这里是保存mask的路径,需要填写最后一次剪枝保存的mask数据路径。 |
initial_weight_path | 存储模型初始值的路径 |
实验结果
1:将last_cifar_5结果迁移到first_cifar_5上
prune level | 直接在firsrt_cifar_5进行训练的准确率 | 在last_cifar_5上训练,迁移到first_cifar_5上的准确率 |
---|---|---|
0% | 60.88% | |
10% | 62.06% | 62.82% |
18.99% | 62.04% | |
27.09% | 62.44% | |
34.38% | 62.52% | 61.66% |
40.93% | 62.34% | 61.54% |
46.84% | 62.74% | |
52.15% | 61.50% | 59.50% |
56.93% | 60.84% | |
61.23% | 59.18% | 57.96% |
1:将first_cifar_5结果迁移到last_cifar_5上
prune level | 直接在last_cifar_5进行训练的准确率 | 在fist_cifar_5上训练,迁移到last_cifar_5上的准确率 |
---|---|---|
0% | 67.92% | |
10% | 69.52% | 68.10% |
18.99% | 68.96% | |
27.09% | 68.92% | |
34.38% | 69.02% | 68.10% |
40.93% | 68.52% | 68.34% |
46.84% | 67.84% | |
52.15% | 67.26% | 68.32% |
56.93% | 66.74% | |
61.23% | 66.20% | 67.72% |