基于强化学习的自动化裁剪CIFAR-10 分类任务(提升模型精度+减少计算量)

68 篇文章 2 订阅
32 篇文章 0 订阅

基于强化学习的自动化裁剪,提升模型精度的同时减少计算量。

介绍

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RFnHlyQG-1691544546106)(./pic/APT-main.png)]

目前的强化学习工作很多集中在利用外部环境的反馈训练agent,忽略了模型本身就是一种能够获得反馈的环境。本项目的核心思想是:将模型视为环境,构建附生于模型的 agent ,以辅助模型进一步拟合真实样本。

大多数领域的模型都可以采用这种方式来优化,如cv\多模态等。它至少能够以三种方式工作:1.过滤噪音信息,如删减语音或图像特征;2.进一步丰富表征信息,如高效引用外部信息;3.实现记忆、联想、推理等复杂工作,如构建重要信息的记忆池。

这里推出一款早期完成的裁剪机制transformer版本(后面称为APT),实现了一种更高效的训练模式,能够优化模型指标;此外,可以使用动态图丢弃大量的不必要单元,在指标基本不变的情况下,大幅降低计算量。

该项目希望为大家抛砖引玉。

在这里插入图片描述

为什么要做自动剪枝

在具体任务中,往往存在大量毫无价值的信息和过渡性信息,有时不但对任务无益,还会成为噪声。比如:表述会存在冗余/无关片段以及过渡性信息;动物图像识别中,有时候背景无益于辨别动物主体,即使是动物部分图像,也仅有小部分是关键的特征。

在这里插入图片描述

以transformer为例,在进行self-attention计算时其复杂度与序列长度平方成正比。长度为10,复杂度为100;长度为9,复杂度为81。

利用强化学习构建agent,能够精准且自动化地动态裁剪已丧失意义部分,甚至能将长序列信息压缩到50-100之内(实验中有从500+的序列长度压缩到个位数的示例),以大幅减少计算量。

实验中,发现与裁剪agent联合训练的模型比普通方法训练的模型效果要更好。

使用说明

环境

torch
numpy
tqdm
tensorboard
ml-collections

下载经过预先​​训练的模型(来自Google官方)

本项目使用的型号:ViT-B_16(您也可以选择其它型号进行测试)

# imagenet21k pre-train
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz

训练与推理

下载好预训练模型就可以跑了。

# 训练
python3 train.py --name cifar10-100_500 --dataset cifar100 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz

# 推理
python3 infer.py --name cifar10-100_500 --dataset cifar100 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz

CIFAR-10和CIFAR-100会自动下载和培训。如果使用其他数据集,您需要自定义data_utils.py。

在裁剪模式的推理过程中,预期您将看到如下格式的输出。

Validating... (loss=0.13492):   1%|| 60/10000 [00:01<02:36, 63.34it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 196, 768])
第5层形状::: torch.Size([1, 188, 768])
第8层形状::: torch.Size([1, 186, 768])
Validating... (loss=0.01283):   1%|| 60/10000 [00:01<02:36, 63.34it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 183, 768])
第5层形状::: torch.Size([1, 166, 768])
第8层形状::: torch.Size([1, 166, 768])
Validating... (loss=3.71401):   1%|| 60/10000 [00:01<02:36, 63.34it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 193, 768])
第5层形状::: torch.Size([1, 191, 768])
第8层形状::: torch.Size([1, 186, 768])
Validating... (loss=0.00328):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 191, 768])
第5层形状::: torch.Size([1, 164, 768])
第8层形状::: torch.Size([1, 123, 768])
Validating... (loss=0.03190):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 193, 768])
第5层形状::: torch.Size([1, 187, 768])
第8层形状::: torch.Size([1, 160, 768])
Validating... (loss=0.00356):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 193, 768])
第5层形状::: torch.Size([1, 187, 768])
第8层形状::: torch.Size([1, 182, 768])
Validating... (loss=0.00297):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 197, 768])
第5层形状::: torch.Size([1, 167, 768])
第8层形状::: torch.Size([1, 162, 768])
Validating... (loss=0.00162):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 189, 768])
第5层形状::: torch.Size([1, 179, 768])
第8层形状::: torch.Size([1, 157, 768])
Validating... (loss=0.08821):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 197, 768])
第5层形状::: torch.Size([1, 174, 768])
第8层形状::: torch.Size([1, 156, 768])

默认的batch size为72、gradient_accumulation_steps为3。当GPU内存不足时,您可以通过它们来进行训练。

注:相较于原始的ViT,APT(Automatic pruning transformer)的训练步数、训练耗时都会上升。原因是使用pruning agent的模型由于总会丢失部分信息,使得收敛速度变慢,同时为了训练pruning agent,也需要多次的观测、行动、反馈。

模型介绍

模型主体

基于transformer的视觉预训练模型ViT是本项目的模型主体,具体细节可以查看论文:《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》
在这里插入图片描述

自动化裁剪的智能体

对于强化学习agent来说,最关键的问题之一是如何衡量动作带来的反馈。为了评估单次动作所带来的影响,使用了以下三步骤:

1.使用一个普通模型(无裁剪模块)进行预测;

2.使用一个带裁剪器的模型(执行一次裁剪动作)进行预测;

3.对比两次预测的结果,若裁剪后损失相对更小,则说明该裁剪动作帮助了模型进一步拟合真实状况,应该得到奖励;反之,应该受到惩罚。

但是在实际预测过程中,模型是同时裁剪多个单元的,这或将因为多个裁剪的连锁反应而导致模型失效。训练过程中需要构建一个带裁剪器的模型(可执行多次裁剪动作),以减小该问题所带来的影响。

综上,本模型使用的是三通道模式进行训练。

在这里插入图片描述

关于裁剪器的模型结构设计,本模型中认为如何衡量一个信息单元是否对模型有意义,建立于其自身的信息及它与任务的相关性上。

因此以信息单元本身及它与CLS单元的交互作为agent的输入信息。

在这里插入图片描述

实验


数据集ViTAPT(pruning)APT(no pruning)
CIFAR-10092.392.693.03
CIFAR-1099.0898.9398.92

以上加载的均为ViT-B_16,resolution为224*224。

致谢

感谢基于pytorch的图像分类项目,本项目是在此基础上做的研发。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
强化学习(Reinforcement Learning, RL),又称再励学习、评价学习或增强学习,是机器学习的范式和方法论之一。它主要用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题。强化学习的特点在于没有监督数据,只有奖励信号。 强化学习的常见模型是标准的马尔可夫决策过程(Markov Decision Process, MDP)。按给定条件,强化学习可分为基于模式的强化学习(model-based RL)和无模式强化学习(model-free RL),以及主动强化学习(active RL)和被动强化学习(passive RL)。强化学习的变体包括逆向强化学习、阶层强化学习和部分可观测系统的强化学习。求解强化学习问题所使用的算法可分为策略搜索算法和值函数(value function)算法两强化学习理论受到行为主义心理学启发,侧重在线学习并试图在探索-利用(exploration-exploitation)间保持平衡。不同于监督学习和非监督学习,强化学习不要求预先给定任何数据,而是通过接收环境对动作的奖励(反馈)获得学习信息并更新模型参数。强化学习问题在信息论、博弈论、自动控制等领域有得到讨论,被用于解释有限理性条件下的平衡态、设计推荐系统和机器人交互系统。一些复杂的强化学习算法在一定程度上具备解决复杂问题的通用智能,可以在围棋和电子游戏中达到人水平。 强化学习在工程领域的应用也相当广泛。例如,Facebook提出了开源强化学习平台Horizon,该平台利用强化学习来优化大规模生产系统。在医疗保健领域,RL系统能够为患者提供治疗策略,该系统能够利用以往的经验找到最优的策略,而无需生物系统的数学模型等先验信息,这使得基于RL的系统具有更广泛的适用性。 总的来说,强化学习是一种通过智能体与环境交互,以最大化累积奖励为目标的学习过程。它在许多领域都展现出了强大的应用潜力。
模型可以参考ResNet等经典模型,以下是一个简单的卷积神经网络设计: 1. 输入层:输入32x32x3的图像数据。 2. 第一层卷积:使用64个3x3的卷积核,步长为1,padding为same,激活函数为ReLU。 3. 第二层卷积:使用64个3x3的卷积核,步长为1,padding为same,激活函数为ReLU。 4. 第一层池化:使用2x2的最大池化。 5. 第三层卷积:使用128个3x3的卷积核,步长为1,padding为same,激活函数为ReLU。 6. 第四层卷积:使用128个3x3的卷积核,步长为1,padding为same,激活函数为ReLU。 7. 第二层池化:使用2x2的最大池化。 8. 第五层卷积:使用256个3x3的卷积核,步长为1,padding为same,激活函数为ReLU。 9. 第六层卷积:使用256个3x3的卷积核,步长为1,padding为same,激活函数为ReLU。 10. 第七层卷积:使用256个3x3的卷积核,步长为1,padding为same,激活函数为ReLU。 11. 第三层池化:使用2x2的最大池化。 12. 全连接层:将输出展平成一维向,连接一个512个神经元的全连接层,激活函数为ReLU。 13. 输出层:连接一个10个神经元的全连接层,激活函数为softmax。 参数优化可以采用Adam优化器,损失函数采用交叉熵损失函数。训练时可以采用数据增强技术,如随机裁剪、随机翻转等,以减小过拟合。同时可以使用学习率衰减技术,如每个epoch结束时将学习率除以10,以提高模型的稳定性和泛化能力。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

极客程序设计

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值