一文理解RetNet

作者:梁德澎  | 来源:GiantPandaCV

前言

微软研究院最近提出了一个新的 LLM 自回归基础架构 Retentive Networks (RetNet)[1,4],该架构相对于 Transformer 架构的优势是同时具备:训练可并行、推理成本低和良好的性能,不可能三角。

c1669a68977f94de932142c64171dc55.png

论文中给出一个很形象的示意图,RetNet 在正中间表示同时具备三个优点,而其他的架构 Linear Transformer、Recurrent Network 和 Transformer 都只能同时具备其中两个有点。

接下来看一下论文给出的 RetNet 和 Transformer 的对比实验结果:

3ec5e4ffe752a951d7c3f78eebf634a3.png

当输入序列长度增加的时候,RetNet 的 GPU 显存占用一直是稳定的和权值差不多,而 Transformer 则是和输入长度成正比。

d3c5d130837b97e1b1dd0272c5b20e3e.png

首先看红色线和紫色线,都是输入长度在 8192 下,RetNet 和 Transformer 推理延时的对比。

可以看到当 batch size 增加的时候, RetNet 的推理延时也还是很稳定,而 Transformer 的推理延时则是和 batch size 成正比。

而 Transformer 即使是输入长度缩小到 1024 ,推理延时也还是比 RetNet 要高。

RetNet 架构解读

RetNet 架构和 Transformer 类似,也是堆叠 层同样的模块,每个模块内部包含两个子模块:一个 multi-scale retention(MSR)和一个 feed-forward network (FFN)。

下面详细解读一下这个 retention 子模块。

首先给定一个输入序列 :

其中 表示序列的长度。然后输入序列首先经过 embedding 层得到词嵌入向量:

其中 表示隐含层的维度。

Retention 机制

首先对给定输入词嵌入向量序列 中的每个时间步 的向量 都乘以权值 得到 :

然后同样有类似 Transformer 架构的 Q 和 K 的投影:

其中 是需要学习的权值。

接着假设现在有一个序列建模的问题,通过状态 将 映射为 向量。首先来看论文中给出的映射方式定义:

其中 是一个矩阵, 表示时间步 对应的 投影则 。同样 表示时间步 对应的 投影。

那么上面公式中的 计算公式是怎么得出来呢,下面详细解释一下,首先将 展开:

其中 表示单位矩阵(主对角线元素为1,其余元素为0的方阵)。然后我们假定 为初始状态元素为全0的矩阵,则有:

再继续上述推导过程:

所以根据上述推导过程和条件归纳可得:

然后我们来看一下 矩阵是什么,论文中定义了 是一个可对角化的矩阵,具体定义为:

其中 都是 维的向量, 是一个可逆矩阵,而要理解 首先得复习一下欧拉公式 [2]

其中 表示任意实数, 是自然对数的底数, 是复数中的虚数单位,也可以表示为实部 ,虚部 的一个复数,欧拉公式[2]建立了指数函数、三角函数和复数之间的桥梁。

而这里 是一个 维向量:

则 也就是将向量元素两两一组表示分别表示为复数的实部和虚部:

然后 就是一个对角矩阵,对角元素的值就对应将 和 转成复数向量相乘再将结果转回实数向量的结果。

关于复数向量相乘可以参考文章: 

一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)

现在我们知道了矩阵 的构成就能得到:

这里因为 是可逆矩阵则有性质

其中 为单位矩阵,则将 次方展开:

就是 个 矩阵相乘,中间相邻的 都消掉了,所以可得:

然后我们回到计算 的公式:

接着论文中提出把 吸收进 和 也就是 和 分别用 和 替代当作学习的权值,那么可得:

接着将公式简化,将 改为一个实数常量,那么可得:

在继续推导前,先来仔细看一下 ,借助欧拉公式展开:

然后复习一下三角函数的性质[3]

则有:

转为复数形式表示就是:

刚好就对应 的共轭

所以可得:

其中 表示共轭转置操作。

Retention 的训练并行表示

首先回顾单个时间步 的输出 的计算公式如下:

而所有时间步的输出是可以并行计算的,用矩阵形式表达如下:

其中 ,而 表示两个矩阵逐元素相乘, 和 每一行对应一个时间步的 q 和 k 向量。

而 每一行对应向量 。 就是对应 矩阵的共轭,也就是将 矩阵每一行改为复数的共轭形式。

而 矩阵是一个下三角矩阵,其中第 行第 列的元素计算方式:

Retention 的推理循环表示

推理阶段的循环表示论文中定义如下:

怎么理解呢,还是先回顾单个时间步 的输出 的计算公式:

上述公式最后一步和推理阶段循环表示公式中各个元素的对应关系是:

对应论文中的图示:ca26f519d2b4e0d9dc3a8a69f265446c.png

图中的 表示 GroupNorm。

可以看到在推理阶段,RetNet 在计算当前时间步 的输出 只依赖于上一个时间步产出的状态矩阵 。

其实就是把计算顺序改了一下,先计算的 和 的相乘然后一直累加到状态矩阵 上,最后再和 相乘。

而不是像 Transformer 架构那样,每个时间步的计算要先算 和前面所有时间步的 相乘得到 attention 权值再和 相乘求和,这样就需要一直保留历史的 和 。

Gated Multi-Scale Retention

然后 RetNet 每一层中的 Retention 子模块其实也是分了 个头,每个头用不同的 参数,同时每个头都采用不同的 常量,这也是  Multi-Scale Retention 名称的来由。

则对输入 , MSR 层的输出是:

其中, , 是激活函数用来生成门控阈值,还有由于每个头均采用不同的 ,所以每个头的输出要单独做 normalize 之后再 concat。

参考资料

  • [1] https://arxiv.org/pdf/2307.08621.pdf

  • [2] https://en.wikipedia.org/wiki/Euler's_formula

  • [3] https://en.wikipedia.org/wiki/List_of_trigonometric_identities

  • [4] https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py

—END—

高效学习3D视觉三部曲

第一步 加入行业交流群,保持技术的先进性

目前工坊已经建立了3D视觉方向多个社群,包括SLAM、工业3D视觉、自动驾驶方向,细分群包括:

[工业方向]三维点云、结构光、机械臂、缺陷检测、三维测量、TOF、相机标定、综合群;

[SLAM方向]多传感器融合、ORB-SLAM、激光SLAM、机器人导航、RTK|GPS|UWB等传感器交流群、SLAM综合讨论群;

[自动驾驶方向]深度估计、Transformer、毫米波|激光雷达|视觉摄像头传感器讨论群、多传感器标定、自动驾驶综合群等。

[三维重建方向]NeRF、colmap、OpenMVS、MVSNet等。

[无人机方向]四旋翼建模、无人机飞控等。

除了这些,还有求职、硬件选型、视觉产品落地等交流群。

大家可以添加小助理微信: dddvisiona,备注:加群+方向+学校|公司, 小助理会拉你入群。

e464819a655a7b1df52df361f3df7fee.jpeg
添加小助理微信: dddvisiona,拉你入群
第二步 加入知识星球,问题及时得到解答
2.1 「3D视觉从入门到精通」技术星球

针对3D视觉领域的视频课程(三维重建、三维点云、结构光、手眼标定、相机标定、激光/视觉SLAM、自动驾驶等)、源码分享、知识点汇总、入门进阶学习路线、最新paper分享、疑问解答等进行深耕,更有各类大厂的算法工程人员进行技术指导。与此同时,星球将联合知名企业发布3D视觉相关算法开发岗位以及项目对接信息,打造成集技术与就业、项目对接为一体的铁杆粉丝聚集区,6000+星球成员为创造更好的AI世界共同进步,知识星球入口:「3D视觉从入门到精通」

学习3D视觉核心技术,扫描查看,3天内无条件退款 b6f34b1ca370d7efee9de3ccc2061800.jpeg
高质量教程资料、答疑解惑、助你高效解决问题
2.2 3D视觉岗求职星球

本星球:3D视觉岗求职星球 依托于公众号「3D视觉工坊」和「计算机视觉工坊」、「3DCV」,旨在发布3D视觉项目、3D视觉产品、3D视觉算法招聘信息,具体内容主要包括:

  • 收集汇总并发布3D视觉领域优秀企业的最新招聘信息。

  • 发布项目需求,包括2D、3D视觉、深度学习、VSLAM,自动驾驶、三维重建、结构光、机械臂位姿估计与抓取、光场重建、无人机、AR/VR等。

  • 分享3D视觉算法岗的秋招、春招准备攻略,心得体会,内推机会、实习机会等,涉及计算机视觉、SLAM、深度学习、自动驾驶、大数据等方向。

  • 星球内含有多家企业HR及猎头提供就业机会。群主和嘉宾既有21届/22届/23届参与招聘拿到算法offer(含有海康威视、阿里、美团、华为等大厂offer)。

  • 发布3D视觉行业新科技产品,触及行业新动向。

5666b199dfeb39f138ef252df122e6ac.jpeg
扫码加入,3D视觉岗求职星球,简历投起来
第三步 系统学习3D视觉,对模块知识体系,深刻理解并运行

如果大家对3D视觉某一个细分方向想系统学习[从理论、代码到实战],推荐3D视觉精品课程学习网址:www.3dcver.com

科研论文写作:

[1]国内首个面向三维视觉的科研方法与学术论文写作教程

基础课程:

[1]面向三维视觉算法的C++重要模块精讲:从零基础入门到进阶

[2]面向三维视觉的Linux嵌入式系统教程[理论+代码+实战]

[3]如何学习相机模型与标定?(代码+实战)

[4]ROS2从入门到精通:理论与实战

[5]彻底理解dToF雷达系统设计[理论+代码+实战]

工业3D视觉方向课程:

[1](第二期)从零搭建一套结构光3D重建系统[理论+源码+实践]

[2]保姆级线结构光(单目&双目)三维重建系统教程

[3]机械臂抓取从入门到实战课程(理论+源码)

[4]三维点云处理:算法与实战汇总

[5]彻底搞懂基于Open3D的点云处理教程!

[6]3D视觉缺陷检测教程:理论与实战!

SLAM方向课程:

[1]深度剖析面向机器人领域的3D激光SLAM技术原理、代码与实战

[2]彻底剖析激光-视觉-IMU-GPS融合SLAM算法:理论推导、代码讲解和实战

[3](第二期)彻底搞懂基于LOAM框架的3D激光SLAM:源码剖析到算法优化

[4]彻底搞懂视觉-惯性SLAM:VINS-Fusion原理精讲与源码剖析

[5]彻底剖析室内、室外激光SLAM关键算法和实战(cartographer+LOAM+LIO-SAM)

[6](第二期)ORB-SLAM3理论讲解与代码精析

机器人导航与路径规划

[1]移动机器人规划控制入门与实践:基于Navigation2

视觉三维重建:

[1]彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进

[2]基于深度学习的三维重建MVSNet系列 [论文+源码+应用+科研]

自动驾驶方向课程:

[1] 深度剖析面向自动驾驶领域的车载传感器空间同步(标定)

[2] 国内首个面向自动驾驶目标检测领域的Transformer原理与实战课程

[3]单目深度估计方法:算法梳理与代码实现

[4]面向自动驾驶领域的3D点云目标检测全栈学习路线!(单模态+多模态/数据+代码)

[5]如何将深度学习模型部署到实际工程中?(分类+检测+分割)

无人机:

[1] 零基础入门四旋翼建模与控制(MATLAB仿真)[理论+实战]

最后

1、3D视觉文章投稿作者招募

2、3D视觉课程(自动驾驶、SLAM和工业3D视觉)主讲老师招募

3、顶会论文分享与3D视觉传感器行业直播邀请

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值