Torch-Pruning 库入门级使用介绍

在这里插入图片描述

项目地址:https://github.com/VainF/Torch-Pruning

Torch-Pruning 是一个专用于torch的模型剪枝库,其基于DepGraph 技术分析出模型layer中的依赖关系。DepGraph 与现有的修剪方法(如 Magnitude Pruning 或 Taylor Pruning)相结合可以达到良好的剪枝效果。

本博文结合项目官网案例,对信息进行结构话,抽离出剪枝技术说明、剪枝模型保存与加载、剪枝技术的基本使用,剪枝技术的具体使用案例。并结合外部信息,分析剪枝对模型性能精度的影响。

1、基本说明

1.1 项目安装

打开https://github.com/VainF/Torch-Pruning,下载项目
在这里插入图片描述
然后在终端中,进入项目目录,并执行pip install -r requirements.txt 安装项目依赖库
在这里插入图片描述
然后在执行 pip install -e . ,将项目安装在当前目录下,并设置为editing模式。
在这里插入图片描述
验证安装:执行命令python -c "import torch_pruning", 如果没有输出报错信息则表示安装成功。
在这里插入图片描述

1.2 DepGraph 技术说明

在结构修剪中,组被定义为深度网络中最小的可移除单元。每个组由多个相互依赖的层组成,需要同时修剪这些层以保持最终结构的完整性。然而,深度网络通常表现出层与层之间错综复杂的依赖关系,这对结构修剪提出了重大挑战。这项研究通过引入一种名为 DepGraph 的自动化机制来解决这一挑战,该机制可以轻松实现参数分组,并有助于修剪各种深度网络。
在这里插入图片描述

直接剪枝会会破坏layer间的依赖关系,会导致forward流程报错。具体如下面代码,移除model.conv1模块中的idxs为0与1的channel,导致后续的bn1层输入输入与参数格式对不上号,然后报错。

from torchvision.models import resnet18
import torch_pruning as tp
import torch

model = resnet18().eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # remove channel 0 and channel 1
output = model(torch.randn(1,3,224,224)) # test

在这里插入图片描述
基本在后续层添加剪枝,运行代码也会保存,因为batchnorm的下一层要求的输出channel是64。

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) 
tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1])
tp.prune_batchnorm_in_channels(model.layer1[0].conv1, idxs=[0,1])
output = model(torch.randn(1,3,224,224)) 

使用DepGraph剪枝代码如下,先使用tp.DependencyGraph().build_dependenc构建出依赖图,然后基于DG.get_pruning_group函数获取目标剪枝层的依赖关系组,最后在检验关系并进行剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. prune all grouped layers that are coupled with model.conv1 (included).
print(group)
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()
    
# 4. Save & Load
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object

代码执行后的输出如下所示,可以看到捕捉到group对应的依赖layer

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

1.3 剪枝模型的保存与加载

剪枝后的模型由于网络结构改变了,如果只保存模型参数,是无法支持原始网络结构,需要将模型结构连参数一并保存。加载时连同参数一起加载。

model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model

或者基于tp库中tp.state_dict函数提取目标参数进行保存,并基于tp.load_state_dict函数将剪枝后的参数赋值到原始模型中形成剪枝模型。

# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')

# create a new model, e.g. resnet18
new_model = resnet18().eval()

# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
print(new_model) # This will be a pruned model.

2、剪枝基本案例

2.1 具有目标结构的剪枝

以下代码使用TaylorImportance指标进行剪枝,设置忽略输出层的剪枝。并设置MagnitudePruner中对通道剪枝50%,一共分iterative_steps步完成剪枝,每一次剪枝都进行微调。
整体来说,具备目标结构的剪枝,效果是最差的。 基于https://blog.csdn.net/a486259/article/details/140407147 分析的数据得出的结论。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

#model = resnet18(pretrained=True)
model = resnet18()

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    #pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)
    # finetune your model here
    # finetune(model)
    # ...

代码的输出信息如下所示,可以看到macs与nparams在逐步降低。最终输出的模型结构,所有的chanel都减半了,只有输出层例外。

iter 0 | rate:0.8092  0.8111
iter 1 | rate:0.6469  0.6445
iter 2 | rate:0.4971  0.4979
iter 3 | rate:0.3718  0.3695
iter 4 | rate:0.2674  0.2614
ResNet(
  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)
PS D:\开源项目\Torch-Pruning-master>
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)

2.2 自动结构剪枝

这里的自动结构是有一个预设目标,即将总体channel剪枝到原模型的多少,但没有预定的目标结构。可能有的laye通道剪枝数多,有的剪枝数少。 与2.1中的代码相比,主要是增加了参数 global_pruning=True。但这个剪枝方式比具有目标结构的剪枝更加有效。就像裁员一样,要求各个部门内裁员比例相同与在公司内控制裁员比例(各个部门裁员比例按重要度排列,裁员比例不一样),必然是第二种方式更有效。第一种方式,使低效率部门的靠前但无用员工保留下来了。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

#model = resnet18(pretrained=True)
model = resnet18()

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 3 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # remove 50%的channel
    ignored_layers=ignored_layers,
    global_pruning=True
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)
    # finetune your model here
    # finetune(model)
    # ...

2.3 MagnitudePruner中的参数

指定特定层的剪枝比例 通过pruning_ratio_dict参数,指定model.layer2的剪枝比例为20%,这里适用于有先验经验的layer,控制对特定layer的剪枝比例。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    imp,
    pruning_ratio = 0.5,
    pruning_ratio_dict = {model.layer2: 0.2}
)
pruner.step()
print(model)

代码执行后的层为:ResNet{64, 128, 256, 512} => ResNet{32, 102, 128, 256}

设置最大剪枝比例 通过 max_pruning_ratio 参数设置最大剪枝比例,避免由于稀疏剪枝或者自动剪枝时某个层被严重剪枝或者移除。

剪枝次数与剪枝调度器 您打算分多轮修剪模型,iterative_steps 会很有用。默认情况下,修剪器会逐渐增加模型的稀疏度,直到达到所需的 pruning_ratio。如以下代码,分5次实现剪枝目标。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
)

# prune the model, iteratively if necessary.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("Round %d/%d, Params: %.2f M" % (i+1, iterative_steps, nparams/1e6))
    # finetune your model here
    # finetune(model)
    # ...
print(model)

对应输出如下
Round 1/5, Params: 9.44 M
Round 2/5, Params: 7.45 M
Round 3/5, Params: 5.71 M
Round 4/5, Params: 4.20 M
Round 5/5, Params: 2.93 M

设置忽略的层 这主要是避免对输出层进行剪枝,修改模型的输出结构。使用代码如下,通过ignored_layers参数传入忽略的layer对象。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels
    ignored_layers=[model.conv1, model.fc] # ignore the first & last layers
)
pruner.step()
print(model)

channel取整 在很多的时候都认为channel为16的倍数,gpu运行效率最高。使用代码如下,通过round_to参数,保持channel是特定数的倍数。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.3, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    round_to=10 # round to 10x. Note: 10x is not a good practice.
)

pruner.step()
print(model)

channel_groups 某些层(例如 nn.GroupNorm 和 nn.Conv2d)具有 group 参数,这会在层内引入额外的依赖项。修剪后,保持所有组的大小相同至关重要。为了满足这一要求,引入了参数 channel_groups 以启用对这些通道的手动分组。如以下代码,通过channel_groups参数,控制model.group_conv1中的参数为8个一组

pruner = tp.pruner.MagnitudePruner(
            model,
            example_inputs=example_inputs,
            importance=importance,
            iterative_steps=1,
            pruning_ratio=0.5,
            channel_groups = {model.group_conv1: 8} # For Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), groups=8)
        )

额外参数剪枝 有些时候模型具备的可训练参数并非conv、fc等传统layer中,需要基于unwrapped_parameters参数将额外的可剪枝参数传入到剪枝器中。具体如下所示:

from torchvision.models.convnext import CNBlock, ConvNeXt
unwrapped_parameters = []
for m in model.modules():
    if isinstance(m, CNBlock):
        unwrapped_parameters.append( (m.layer_scale, 0) )

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, 
    unwrapped_parameters=unwrapped_parameters 

限定剪枝范围 root_module_types 参数用于指定组的“根”或第一层。在许多情况下,我们专注于修剪线性层和卷积 (Conv) 层。要专门针对这些层启用修剪,我们可以使用以下参数:root_module_types=[nn.Conv2D, nn.Linear]。这可确保将修剪应用于所需的层。

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, 
    root_module_types=[nn.Conv2D, nn.Linear]

3、具体应用案例

3.1 timm模型剪枝

官方代码为:examples\timm_models\prune_timm_models.py
具体详情如下,这里有一个特殊用法,是通过num_heads参数实现对于transformer layer的支持

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
os.environ['TIMM_FUSED_ATTN'] = '0'
import torch
import torch.nn as nn 
import torch.nn.functional as F
from typing import Sequence
import timm
from timm.models.vision_transformer import Attention
import torch_pruning as tp
import argparse

parser = argparse.ArgumentParser(description='Prune timm models')
parser.add_argument('--model', default=None, type=str, help='model name')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='channel pruning ratio')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
parser.add_argument('--pretrained', default=False, action='store_true', help='global pruning')
parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
args = parser.parse_args()

def main():
    timm_models = timm.list_models()
    if args.list_models:
        print(timm_models)
    if args.model is None: 
        return
    assert args.model in timm_models, "Model %s is not in timm model list: %s"%(args.model, timm_models)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)

    imp = tp.importance.GroupNormImportance()
    print("Pruning %s..."%args.model)
        
    input_size = model.default_cfg['input_size']
    example_inputs = torch.randn(1, *input_size).to(device)
    test_output = model(example_inputs)
    ignored_layers = []
    num_heads = {}

    for m in model.modules():
        if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
            ignored_layers.append(model.head)
            print("Ignore classifier layer: ", m.head)
       
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                num_heads[m.qkv] = m.num_heads
                print("Attention layer: ", m.qkv, m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                num_heads[m.qkv_proj] = m.num_heads

    print("========Before pruning========")
    print(model)
    base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
    pruner = tp.pruner.MetaPruner(
                    model, 
                    example_inputs, 
                    global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.
                    importance=imp, # importance criterion for parameter selection
                    iterative_steps=1, # the number of iterations to achieve target pruning ratio
                    pruning_ratio=args.pruning_ratio, # target pruning ratio
                    num_heads=num_heads,
                    ignored_layers=ignored_layers,
                )
    for g in pruner.step(interactive=True):
        g.prune()

    for m in model.modules():
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                m.num_heads = num_heads[m.qkv]
                m.head_dim = m.qkv.out_features // (3 * m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                m.num_heads = num_heads[m.qqkv_projkv]
                m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

    print("========After pruning========")
    print(model)
    test_output = model(example_inputs)
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
    print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
    print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))

if __name__=='__main__':
    main()

3.2 llm模型剪枝

在examples\LLMs\prune_llama.py中提供了一个对于llama模型的剪枝案例.
核心代码如下,可以看到也是基于num_heads记录transformer的结构信息,然后在剪枝后将num_heads数据赋值到对应模型参数上。与原始代码相比,这里删除了模型精度验证相关的代码。


# Code adapted from 
# https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
# https://github.com/locuslab/wanda

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))

import argparse
import os 
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version
import time
import torch
import torch.nn as nn
from collections import defaultdict
import fnmatch
import numpy as np
import random

print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="./cache"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.float16, 
        cache_dir=cache_dir, 
        device_map="auto"
    )

    model.seqlen = model.config.max_position_embeddings 
    return model

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--cache_dir", default="./cache", type=str )
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
    parser.add_argument("--eval_zero_shot", action="store_true")
    args = parser.parse_args()

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)       
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    ##############
    # Pruning
    ##############
    print("----------------- Before Pruning -----------------")
    print(model)
    text = "Hello world."
    inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
    import torch_pruning as tp 
    num_heads = {}
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_key_value_heads
            num_heads[m.v_proj] = model.config.num_key_value_heads
            
    head_pruning_ratio = args.pruning_ratio
    hidden_size_pruning_ratio = args.pruning_ratio
    pruner = tp.pruner.MagnitudePruner(
        model, 
        example_inputs=inputs,
        importance=tp.importance.GroupNormImportance(),
        global_pruning=False,
        pruning_ratio=hidden_size_pruning_ratio,
        ignored_layers=[model.lm_head],
        num_heads=num_heads,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=head_pruning_ratio,
    )
    pruner.step()

    # Update model attributes
    num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )
    num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )
    model.config.num_attention_heads = num_heads
    model.config.num_key_value_heads = num_key_value_heads
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            m.hidden_size = m.q_proj.out_features
            m.num_heads = num_heads
            m.num_key_value_heads = num_key_value_heads
        elif name.endswith("mlp"):
            model.config.intermediate_size = m.gate_proj.out_features
    print("----------------- After Pruning -----------------")
    print(model)

    #ppl_test = eval_ppl(args, model, tokenizer, device)
    #print(f"wikitext perplexity {ppl_test}")

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)

if __name__ == '__main__':
    main()

3.3 目标检测模型剪枝

在Torch-Pruning 库中提供了针对yolov8、yolov7、yolov5的剪枝案例。关于yolov8还提供了剪枝后的训练策略,其主要技巧在与对不可剪枝层的可剪枝话处理(C2f模块的剪枝,其含split操作,不利于剪枝索引)。后续会补充博客,说明对yolov8的剪枝使用。

4、其他信息

4.1 剪枝器中的评价指标

在torch_pruning\pruner\importance.py中有很多个剪枝评价指标

__all__ = [
    # Base Class
    "Importance",

    # Basic Group Importance
    "GroupNormImportance",
    "GroupTaylorImportance",
    "GroupHessianImportance",

    # Aliases
    "MagnitudeImportance",
    "TaylorImportance",
    "HessianImportance",

    # Other Importance
    "BNScaleImportance",
    "LAMPImportance",
    "RandomImportance",
]

整体来看是TaylorImportance最好,一直使用该值即可。
来看

4.2 剪枝对性能精度的影响

在博客https://blog.csdn.net/a486259/article/details/140407147?spm=1001.2014.3001.5501 中基本确定了剪枝50%,对模型精度是没有任何影响的。这里对Torch-Pruning 库相关的论文数据进行二次核验,以致于分析出剪枝中速度提升对精度的影响。

以DepGraph: Towards Any Structural Pruning数据为例,可以发现最高支持6x速度剪枝后保持模型性能。
在这里插入图片描述
以LLM-Pruner: On the Structural Pruning of Large Language Models 论文数据为例,可以发现使用Vector评价方法的剪枝,移除10%的参数,zero-shot下对模型精度影响不大。而图4更表明,剪枝方法正确的话,移除50%的参数对模型性能影响也不大。
在这里插入图片描述
以论文 Structural Pruning for Diffusion Models 的数据为分析,同样可以发现剪枝50%左右的通道,对结果影响不对。
在这里插入图片描述

  • 25
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: 很高兴能回答您的问题,torch-scatter是一个PyTorch,它提供了一系列函数来处理数据中的索引和深度学习中的稀疏数据。它可以用于计算距离,更新索引,扩展索引,处理数据块,处理稀疏数据,进行数据平衡,调整大小,等等。 ### 回答2: torch-scatter是一个用于在PyTorch中进行稀疏张量的聚合操作的。在深度学习中,稀疏张量是一种特殊的数据结构,其中大部分元素都是零。与普通的密集张量相比,稀疏张量在存储和计算上更加高效。而torch-scatter提供了一些方便的函数,可以针对稀疏张量进行聚合操作。 具体来说,torch-scatter提供了一些常用的聚合函数,如求和、均值、最大值等。这些函数能够对稀疏张量进行聚合计算,并返回聚合结果。与传统的dense tensor相比,对于稀疏张量的聚合计算,torch-scatter能够更高效地处理,节省内存和计算资源。 该还提供了一些高级功能,如自定义聚合函数和处理多个输入张量的聚合等。这些功能使得用户可以根据自己的需求,定义并执行复杂的稀疏张量聚合计算。 总的来说,torch-scatter为PyTorch用户提供了一种便捷而高效的处理稀疏张量的方式,使得稀疏张量的聚合计算更加方便和灵活。同时,该也为深度学习领域中以稀疏张量为基础的算法研究提供了很好的支持。 ### 回答3: torch-scatter是PyTorch中的一个扩展,主要用于执行图数据的分散(scatter)操作。图数据是指由节点和边构成的复杂数据结构,通常用于表示非结构化数据,如社交网络、知识图谱等。 torch-scatter通过提供一系列高效的图聚合操作,使得在图数据上进行计算更加方便和高效。其中最常用的操作是scatter_add函数,它允许在图节点上对特征进行聚合,生成全局的节点特征表示。 具体来说,torch-scatter可以执行以下操作: 1. scatter_add: 将每个节点的特征按照图边的连接关系进行聚合,并返回聚合结果。这对于实现图卷积网络(GCN)等图神经网络模型非常关键。 2. scatter_mean: 类似于scatter_add,但是将节点的特征聚合为均值。 3. scatter_max: 类似于scatter_add,但是将节点的特征聚合为最大值。 4. scatter_min: 类似于scatter_add,但是将节点的特征聚合为最小值。 5. scatter_mul: 类似于scatter_add,但是将节点的特征进行乘法聚合。 除了上述操作外,torch-scatter还提供了一些其他的辅助函数,如index_select、index_add等,用于快速和灵活地处理图数据。 总之,torch-scatter是PyTorch中一个强大的图聚合操作,提供了高效的图数据处理方法,方便用户在图神经网络模型中进行计算和研究。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

万里鹏程转瞬至

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值