网络剪枝——network-slimming 项目复现

目录

网络剪枝——network-slimming 项目复现

  • 【GiHnub】:Eric-mingjie/network-slimming: Network Slimming (Pytorch) (ICCV 2017) (github.com)
  • 【作者复现项目】:
  • 通过百度网盘分享的文件:network-slimming-regin.zip
    链接:https://pan.baidu.com/s/1vTJSLS5ZDjE8R8XaApW96A?pwd=t1z2
    提取码:t1z2
    • 仅以 CIFAR-10 为例,CIFAR-100 同理.
    • 提供中文README_zh-CN.md.
    • 包含 CIFAR-10/100 数据集data.cifar10data.cifar100.
    • 解决了 main.py 运行报错问题.
    • 加入了计算训练后模型的 Parameters 大小脚本param_counter.py.

clone 存储库

注:若 clone 作者复现项目,则忽略这一步,直接进入下一步;若想自行从头复现,则 clone 以下存储库.

  • 链接:https://pan.baidu.com/s/1nppPLKoiPbJPW60HOa2TxQ?pwd=ud89
    提取码:ud89


Baseline

vgg

训练
  • 【命令】:
python main.py --dataset cifar10 --arch vgg --depth 19

  • 这个报错通常出现在使用 Python 的multiprocessing库来创建进程时,尤其是在 Windows 操作系统上. 在 Windows 上,Python 的multiprocessing模块启动新进程的方式与 Linux 或 macOS 不同,它使用 “spawn” 来启动新进程,这意味着每个子进程都会从头开始执行脚本. 因此,如果在脚本顶层级别启动进程(而不是在受保护的if __name__ == '__main__':块中),每个子进程都会尝试再次启动子进程,从而导致无限递归和上述错误.
  • 为了解决这个问题,应 确保多进程代码(即main.py)位于if __name__ == '__main__':保护块内.
# 导入部分
...

def main():
    ...


if __name__ == '__main__':
    main()
  • 再次运行命令,又报错:

  • 这个报错通常发生在尝试直接索引一个0维的张量(tensor)时. 在 PyTorch 中,0 维张量是一个单一值的张量,但是不能像普通的数组那样通过索引来访问。要从 0 维张量中获取其 Python 数值,需要使用.item()方法.
  • 为了解决这个问题,应该 使用.item()方法来替换所有.data[0]的用法
# 在 train 函数中
if batch_idx % args.log_interval == 0:
    print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
               100. * batch_idx / len(train_loader), loss.item()))

# 在 test 函数中
for data, target in test_loader:
    if args.cuda:
        data, target = data.cuda(), target.cuda()
    data, target = Variable(data), Variable(target)
    output = model(data)
    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

test_loss /= len(test_loader.dataset)
  • 再次运行命令就正常运行了:

结果
  • Terminal

  • 在 ./logs 生成文件checkpoint.pth.tarmodel_best.pth.tar

resnet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch densenet --depth 40
结果


Sparsity

vgg

训练
  • 【命令】:
python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
结果

resnet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40
结果


Prune

vgg

命令
python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model ./results/CIFAR10_results/CIFAR10-Vgg/Sparsity/model_best.pth.tar --save ./prunes

  • main.py同理,为了解决这个问题,应 确保多进程代码位于if __name__ == '__main__':保护块内
# 导入部分
...

def main():
    ...


if __name__ == '__main__':
    main()
  • 之后就可以正常运行了.

结果
  • Terminal

  • 在./prunes生成文件prune.txtpruned.pth.tar

  • prune.txt中我们可以看到 Number of parametersTest accuracy

resnet

命令
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Resnet-164/Sparsity/model_best.pth.tar --save ./prunes
结果

densenet

命令
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Densenet-40/Sparsity/model_best.pth.tar --save ./prunes
结果


Fine-tune

vgg

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Vgg/Prune/pruned.pth.tar --dataset cifar10 --arch vgg --depth 19 --epochs 160
结果

resnet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Resnet-164/Prune/pruned.pth.tar --dataset cifar10 --arch resnet --depth 164 --epochs 160
结果

densenet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Densenet-40/Prune/pruned.pth.tar --dataset cifar10 --arch densenet --depth 40 --epochs 160
结果


模型大小计算脚本 param_counter.py

  • 【路径】:./script/param_counter.py
import torch


def load_model(model_path):
    model = torch.load(model_path, map_location=torch.device('cpu'))
    return model


def count_parameters(model_state_dict):
    total_params = sum(p.numel() for p in model_state_dict.values())
    return total_params


def get_model_parameters(model_path):
    # 加载模型状态字典
    model = load_model(model_path)

    # 模型状态字典存储在 'state_dict' 键下
    model_state_dict = model['state_dict'] if 'state_dict' in model else model

    # 计算参数总数
    total_params = count_parameters(model_state_dict)
    return total_params
  • main.py中:
from script.param_counter import get_model_parameters

def main():
    ...
    # 计算 Parameters
    model_path = 'logs/model_best.pth.tar'
    total_params = get_model_parameters(model_path)
    print(f'Total parameters in the model: {total_params}')

结果汇总

注:与原项目结果略有差别.

CIFAR10

CIFAR10-VggBaselineSparsity(1e-4)Prune(70%)Fine-tune-160(70%)
Top1 Accuracy(%)93.7293.6033.9893.75
Parameters20.05M20.05M2.22M2.23M
CIFAR10-Resnet-164BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.9995.0094.5995.27
Parameters1.74M1.74M1.46M1.49M
CIFAR10-Densenet-40BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.1594.3794.1494.48
Parameters1.09M1.09M0.70M0.72M
  • 19
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
回溯法是解决0-1背包问题的一种常用方法。该问题是指在给定n种物品和一个容量为C的背包的情况下,如何选择装入背包的物品,使得装入背包中物品的总价值最大。回溯法的基本思路是搜索所有可能的解,并在搜索过程中剪枝,以达到减少搜索次数的目的。具体实现可以参考引用中的递归函数rKnap。 在回溯法中,我们首先将物品按照单位重量的价值递减排序,然后从第一个物品开始搜索。对于每个物品,我们有两种选择:将其放入背包或不放入背包。如果将其放入背包,我们需要检查当前背包容量是否足够,如果足够,则将其放入背包,并更新当前背包的重量和价值。然后递归搜索下一个物品。如果不将其放入背包,则直接递归搜索下一个物品。在搜索过程中,我们需要记录当前背包的重量和价值,以及当前最优解的最大价值。如果当前背包的价值已经超过当前最优解的最大价值,则可以剪枝,不再继续搜索。 C++代码实现可以参考以下范例: <<范例: #include <iostream> #include <algorithm> using namespace std; const int MAXN = 100; int n, c; int w[MAXN], v[MAXN]; int bestv = 0, curv = 0, curw = 0; void backtrack(int i) { if (i > n) { bestv = max(bestv, curv); return; } if (curw + w[i] <= c) { curw += w[i]; curv += v[i]; backtrack(i + 1); curw -= w[i]; curv -= v[i]; } if (curv + v[i] * (c - curw) > bestv) { backtrack(i + 1); } } int main() { cin >> n >> c; for (int i = 1; i <= n; i++) { cin >> w[i] >> v[i]; } sort(w + 1, w + n + 1); sort(v + 1, v + n + 1); backtrack(1); cout << bestv << endl; return 0; } >>
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值