The Three Pillars of Model Compression in Deep Learning: Hands-On ResNet18 Pruning (2)

Preface

In Part 1, The Three Pillars of Model Compression in Deep Learning: Hands-On ResNet18 Pruning (1), I pruned only the conv1 layer, purely to walk through the pruning workflow end to end. As noted there, pruning just conv1 has an outsized impact, for two reasons:

  • Importance of low-level features: conv1 produces the most basic image features, and every later layer builds on them. Pruning conv1 directly caps the representational power of all subsequent layers.
  • Structural chain reaction: shrinking conv1's output channels forces adjustments in bn1, layer1.0.conv1, the downsample branch, and other modules; a slip in any of them (mismatched channel counts, poor parameter initialization) drags down overall performance.

Although that simple validation did not look bad, on your own data, or with more complex features or models, conv1 pruning can hurt overall performance. We therefore made the following changes to the original code:

  1. Pruning target adjusted: prune layer2.0.conv1 instead of conv1, reducing damage to low-level features.
  2. Channel scoring improved: collect activations via a forward pass and prune the lowest-activation channels first, which better reflects each channel's actual feature contribution.
  3. Fine-tuning strategy improved: dynamically unfreeze the pruned layer plus the associated BN and downsample layers, lower the learning rate (0.0001), and increase the fine-tuning epochs (10) so the parameters fully adapt.

These changes noticeably improve the stability and accuracy of the pruned model. During fine-tuning, watch whether the loss keeps falling; if it drops slowly, lower the learning rate further (e.g., to 0.00001).
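One way to automate that adjustment (a minimal sketch, not part of the repository code; it assumes the optimizer and pruned_model from the fine-tuning step below) is PyTorch's ReduceLROnPlateau scheduler:

    from torch.optim.lr_scheduler import ReduceLROnPlateau

    # Cut the LR by 10x whenever the monitored loss stalls for 2 consecutive epochs,
    # bottoming out at the 0.00001 suggested above.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                                  patience=2, min_lr=0.00001)
    # Inside the training loop, call scheduler.step(epoch_loss) once per epoch.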
All the code is here: https://gitee.com/NOON47/model_prune

Detailed Changes

  1. Pruning target adjusted: prune layer2.0.conv1 instead of conv1, reducing damage to low-level features.
    layer_to_prune = 'layer2.0.conv1'  # explicitly name the layer to prune
    pruned_model = prune_conv_layer(model, layer_to_prune, amount=0.2)
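If you are unsure which layer names are valid targets, every prunable Conv2d on a stock torchvision resnet18 can be listed with a quick sketch like this (assuming that model; the printed shape is illustrative):

    import torch.nn as nn
    from torchvision.models import resnet18

    model = resnet18(num_classes=10)
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            print(name, tuple(module.weight.shape))  # e.g. layer2.0.conv1 (128, 64, 3, 3)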
  2. Channel scoring improved: collect activations via a forward pass and prune the lowest-activation channels first, which better reflects each channel's actual feature contribution.
    model.eval()
    with torch.no_grad():
        test_input = torch.randn(128, 3, 32, 32).to(device)  # simulated CIFAR-10 batch
        features = []
        def hook_fn(module, input, output):
            features.append(output)
        handle = layer.register_forward_hook(hook_fn)
        model(test_input)
        handle.remove()
        activation = features[0]  # shape: [128, out_channels, H, W]
        channel_importance = activation.mean(dim=(0, 2, 3))  # mean activation per channel

    num_channels = weight.shape[0]
    num_prune = int(num_channels * amount)
    _, indices = torch.topk(channel_importance, k=num_prune, largest=False)
    mask = torch.ones(num_channels, dtype=torch.bool)
    mask[indices.cpu()] = False  # pruning mask: False marks pruned channels (indices moved to CPU to match the mask)
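Note that the mean of a raw (pre-BN, pre-ReLU) convolution output can be negative, so ranking by mean prunes the most negative-responding channels rather than the weakest ones. For comparison only (not part of the repository code), a common data-free alternative scores each output channel by the L1 norm of its kernel:

    # Hedged alternative: weight-magnitude importance, no forward pass required
    channel_importance = layer.weight.data.abs().sum(dim=(1, 2, 3))  # L1 norm per output channel
    _, indices = torch.topk(channel_importance, k=num_prune, largest=False)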
  3. Fine-tuning strategy improved: dynamically unfreeze the pruned layer plus the associated BN and downsample layers, lower the learning rate (0.0001), and increase the fine-tuning epochs (10) so the parameters fully adapt.
    print("Starting fine-tuning of the pruned model")
    # Dynamically unfreeze the layers related to the pruned layer (here layer2.0.conv1)
    pruned_layer_prefix = layer_to_prune.rpartition('.')[0]  # e.g. 'layer2.0'
    for name, param in pruned_model.named_parameters():
        if (pruned_layer_prefix in name) or ('fc' in name) or ('bn' in name):  # unfreeze the pruned block, every BN layer, and the fc classifier
            param.requires_grad = True
        else:
            param.requires_grad = False
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.0001)  # lower fine-tuning learning rate
    pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=10)  # more fine-tuning epochs
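Before launching the run, it is worth confirming that the unfreezing logic selected what you expect; a small illustrative check:

    # Confirm which parameters will actually be updated during fine-tuning
    trainable = sum(p.numel() for p in pruned_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in pruned_model.parameters())
    print(f"Trainable params: {trainable:,} / {total:,}")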

The complete pruning function:

import torch
import torch.nn as nn

def prune_conv_layer(model, layer_name, amount=0.2):
    device = next(model.parameters()).device
    layer = dict(model.named_modules())[layer_name]
    weight = layer.weight.data

    # Estimate channel importance from activations
    model.eval()
    with torch.no_grad():
        test_input = torch.randn(128, 3, 32, 32).to(device)  # simulated CIFAR-10 batch
        features = []
        def hook_fn(module, input, output):
            features.append(output)
        handle = layer.register_forward_hook(hook_fn)
        model(test_input)
        handle.remove()
        activation = features[0]  # shape: [128, out_channels, H, W]
        channel_importance = activation.mean(dim=(0, 2, 3))  # mean activation per channel

    num_channels = weight.shape[0]
    num_prune = int(num_channels * amount)
    _, indices = torch.topk(channel_importance, k=num_prune, largest=False)
    mask = torch.ones(num_channels, dtype=torch.bool)
    mask[indices.cpu()] = False  # pruning mask: False marks pruned channels (indices moved to CPU to match the mask)

    # Create the replacement convolution layer
    new_conv = nn.Conv2d(
        in_channels=layer.in_channels,
        out_channels=num_channels - num_prune,
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        bias=layer.bias is not None
    ).to(device)
    new_conv.weight.data = layer.weight.data[mask]  # keep only the surviving output channels
    if layer.bias is not None:
        new_conv.bias.data = layer.bias.data[mask]

    # Swap the new convolution into the model
    parent_name, sep, name = layer_name.rpartition('.')
    parent = model.get_submodule(parent_name)
    setattr(parent, name, new_conv)

    # Special handling when the stem conv1 is pruned
    if layer_name == 'conv1':
        # Update the stem BN layer (bn1)
        bn1 = model.bn1
        new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)
        with torch.no_grad():
            new_bn1.weight.data = bn1.weight.data[mask].clone()
            new_bn1.bias.data = bn1.bias.data[mask].clone()
            new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()
            new_bn1.running_var.data = bn1.running_var.data[mask].clone()
        model.bn1 = new_bn1

        # Handle the downsample branch of layer1.0 (create one if it does not exist)
        block = model.layer1[0]
        if not hasattr(block, 'downsample') or block.downsample is None:
            # Create a 1x1 conv + BN to match channel counts
            downsample_conv = nn.Conv2d(
                in_channels=new_conv.out_channels,
                out_channels=block.conv2.out_channels,  # match the main path's output channels (64 for ResNet18 layer1)
                kernel_size=1,
                stride=1,
                bias=False
            ).to(device)
            # Initialize as an identity projection: each surviving channel is routed
            # back to its original position, keeping the residual addition consistent
            with torch.no_grad():
                w = torch.zeros(downsample_conv.out_channels, new_conv.out_channels, 1, 1, device=device)
                for new_idx, old_idx in enumerate(torch.nonzero(mask).squeeze(1).tolist()):
                    w[old_idx, new_idx, 0, 0] = 1.0
                downsample_conv.weight.data = w
            
            downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)
            with torch.no_grad():
                downsample_bn.weight.data.fill_(1.0)
                downsample_bn.bias.data.zero_()
                downsample_bn.running_mean.data.zero_()
                downsample_bn.running_var.data.fill_(1.0)
            
            block.downsample = nn.Sequential(downsample_conv, downsample_bn)
            print("✅ 为 layer1.0 添加新的 downsample 层")
        else:
            # 调整已有 downsample 层的输入通道
            downsample_conv = block.downsample[0]
            downsample_conv.in_channels = new_conv.out_channels
            downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone()).to(device)
            # 更新对应的 BN 层
            downsample_bn = block.downsample[1]
            new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)
            with torch.no_grad():
                new_downsample_bn.weight.data = downsample_bn.weight.data.clone()
                new_downsample_bn.bias.data = downsample_bn.bias.data.clone()
                new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()
                new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()
            block.downsample[1] = new_downsample_bn

        # Sync the input channels of layer1.0.conv1
        target_conv = model.layer1[0].conv1
        if target_conv.in_channels != new_conv.out_channels:
            print(f"Syncing layer1.0.conv1 input channels: {target_conv.in_channels} → {new_conv.out_channels}")
            target_conv.in_channels = new_conv.out_channels
            target_conv.weight = nn.Parameter(target_conv.weight.data[:, mask, :, :].clone().to(device))

    else:
        # Pruning an intermediate layer (e.g. layer2.0.conv1)
        block_prefix = layer_name.rsplit('.', 1)[0]  # block prefix, e.g. 'layer2.0'
        block = model.get_submodule(block_prefix)     # the enclosing block (e.g. layer2.0)

        # Update the BN layer inside this block (conv1 pairs with bn1, conv2 with bn2)
        target_bn_name = f"{block_prefix}.bn1" if 'conv1' in layer_name else f"{block_prefix}.bn2"
        try:
            target_bn = model.get_submodule(target_bn_name)
            new_bn = nn.BatchNorm2d(new_conv.out_channels).to(device)
            with torch.no_grad():
                new_bn.weight.data = target_bn.weight.data[mask].clone()
                new_bn.bias.data = target_bn.bias.data[mask].clone()
                new_bn.running_mean.data = target_bn.running_mean.data[mask].clone()
                new_bn.running_var.data = target_bn.running_var.data[mask].clone()
            setattr(block, target_bn_name.split('.')[-1], new_bn)  # swap in the new BN layer
            print(f"✅ Updated BN layer {target_bn_name} for pruned layer {layer_name}")
        except AttributeError:
            print(f"⚠️ No matching BN layer found for pruned layer {layer_name}; skipping BN update")

        # Sync the input channels of the next convolution (pruning conv1 affects conv2)
        if 'conv1' in layer_name:
            next_conv = block.conv2
            if next_conv.in_channels != new_conv.out_channels:
                print(f"Syncing {block_prefix}.conv2 input channels: {next_conv.in_channels} → {new_conv.out_channels}")
                next_conv.in_channels = new_conv.out_channels
                next_conv.weight = nn.Parameter(next_conv.weight.data[:, mask, :, :].clone().to(device))  # select surviving input channels via the mask

        # Optional: if this block has a downsample branch, adjust its input channels (enable as needed)
        # if hasattr(block, 'downsample') and block.downsample is not None:
        #     downsample_conv = block.downsample[0]
        #     downsample_conv.in_channels = new_conv.out_channels
        #     downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone().to(device))
        #     print(f"✅ Adjusted downsample input channels for pruned layer {layer_name}")

    # Sanity-check the forward pass
    with torch.no_grad():
        test_input = torch.randn(1, 3, 32, 32).to(device)
        try:
            model(test_input)
            print("✅ 前向传播验证通过")
        except Exception as e:
            print(f"❌ 验证失败: {str(e)}")
            raise

    return model
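The size tables in the next section match torchinfo's summary format; assuming that library, the before/after comparison can be reproduced with a sketch along these lines:

    import torch
    from torchvision.models import resnet18
    from torchinfo import summary  # assumed: the tables below match torchinfo's output format

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = resnet18(num_classes=10).to(device)

    summary(model, input_size=(1, 3, 32, 32))   # size report before pruning
    pruned_model = prune_conv_layer(model, 'layer2.0.conv1', amount=0.2)
    summary(pruned_model, input_size=(1, 3, 32, 32))   # size report after pruning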

Results After the Changes

With these changes and the longer fine-tuning schedule, the results are as follows:

Model size before pruning:
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (M): 37.03
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.81
Params size (MB): 44.73
Estimated Total Size (MB): 45.55
==========================================================================================
Original model accuracy: 81.42%

Model size after pruning:
==========================================================================================
Total params: 11,138,392
Trainable params: 11,138,392
Non-trainable params: 0
Total mult-adds (M): 36.33
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.80
Params size (MB): 44.55
Estimated Total Size (MB): 45.37
==========================================================================================
Pruned model accuracy: 83.28%

Pruning a mid-level layer trimmed only about 0.4% of the parameters (11,181,642 → 11,138,392), yet after the longer fine-tuning the accuracy rose from 81.42% to 83.28%. Personally, I consider this the workflow that actually fits real-world applications.
