5min看懂torch.einsum()计算方法-torch.einsum()手动推导详解

引言: torch.einsum()的分析和介绍已经有很多博客介绍过了, 但大多数的落脚点都是爱因斯坦求和约定,许多篇幅是用于介绍爱因斯坦求和约定到的各项法则,而实际案例分析方面只是草草给出一笔带过,涉及到的案例也较为简单。而实际我们要用到或者看到torch.einsum()的时候往往是在计算非常复杂的情况下。
因此本文将从实际复杂案例的角度对torch.einsum()的计算过程进行分析,一步一步的推导最终输出的每个元素和输入元素之间的关系。

爱因斯坦求和约定

 首先,torch.einsum()的基础原理是爱因斯坦求和约定,此处为了行文的整体性将对其进行简要的介绍,如果只关注计算本身,可以跳到下一节。爱因斯坦求和约定是为了简化计算而诞生的一种“记法”,就类似于我们用 × \times ×来标记乘法一样,不同之处在于爱因斯坦求和约定可表示的运算更为复杂、灵活性也更高。爱因斯坦求和约定的典型写法为:
i 1 i 2 . . . i N , j 1 j 2 . . . j M → i k 1 i k 2 . . j l 1 j l 1 , k 1 . . . ∈ N , l 1 . . l ∈ M i_1i_2...i_N,j_1j_2...j_M\rightarrow i_{k_1}i_{k_2}..j_{l_1}j_{l_1},k_1...\in N,l_1..l\in M i1i2...iN,j1j2...jMik1ik2..jl1jl1,k1...N,l1..lM

其中左端 i 1 i 2 . . . i N , j 1 j 2 . . . j M i_1i_2...i_N,j_1j_2...j_M i1i2...iN,j1j2...jM就表示了输入两个矩阵元素的坐标索引,右端 i k 1 i k 2 . . j l 1 j l 1 i_{k_1}i_{k_2}..j_{l_1}j_{l_1} ik1ik2..jl1jl1为输出矩阵元素的坐标索引,可以看到输出矩阵元素索引相较于输入端的索引可能会缺少几项,运算就是发生这几个维度上的乘累加操作。

其中同时出现在左端和右端的坐标索引为自由索引,只用于标记位置;而仅仅出现在右端的索引为求和索引,爱因斯坦求和约定的本质就是沿着求和索引的方向计算两个输入逐元素乘累加和的结果放到输出自由索引的位置上,更为细致的介绍参见:一文学会 Pytorch 中的 einsumeinsum:爱因斯坦求和约定
举例而言:
i j , j k → i k ij,jk\rightarrow ik ij,jkik
就表示沿着 j j j这个维度进行乘累加操作:
O i k = ∑ j A i j B j k O_{ik}=\sum_{j}A_{ij}B_{jk} Oik=jAijBjk
输出的第 ( i , k ) (i,k) (i,k)个元素为 A i ⋅ A_{i \cdot} Ai的行向量和 B ⋅ k B_{\cdot k} Bk列向量逐元素乘累加,实际上就是矩阵相乘。

复杂案例推导

 正如第一节中所介绍的,torch.enisum()的核心计算过程就是沿着只在算式右边出现的轴对输入矩阵元素进行乘累加得到对应位置的输出元素。因此,想要弄清一个复杂的torch.eisum()表达式含义需要做的也只是将这个求和公式写出来再仔细分析。

案例. 四维张量乘三维张量

 给出一个复杂案例:
n c j t , n p j − > n c p t ncjt,npj->ncpt ncjt,npj>ncpt
则其输出元素可以表示为:
C n c p t = ∑ j A n c j t B n p j C_{ncpt}=\sum_j A_{ncjt}B_{npj} Cncpt=jAncjtBnpj
 首先我们可以注意到对于C的第一维 n n n而言,它同时出现在A和B的首位,也就是对于这一维的每个元素,都是会逐元素的执行A和B剩余维度的计算再在当前维度上排布,用深度学习中的描述来说就是对BATCH中的每个元素都独立的执行后续子操作,子操作可以记为:
C c p t = ∑ j A c j t B p j C_{cpt}=\sum_j A_{cjt}B_{pj} Ccpt=jAcjtBpj
紧接着,对当前算式的第一维 c c c来说它只出现在 A A A中,每沿着 c c c计算一个不同的元素都要和“相同”的B计算,也就出现了广播机制,B有了个隐藏的、元素重复的维度 c c c,计算变为 C c p t = ∑ j A c j t B c p j C_{cpt}=\sum_j A_{cjt}B_{cpj} Ccpt=jAcjtBcpj,同第一步计算的原理,这里又可以化简成逐元素的子操作:
C p t = ∑ j A j t B p j = ∑ j B p j A j t C_{pt}=\sum_jA_{jt}B_{pj}=\sum_jB_{pj}A_{jt} Cpt=jAjtBpj=jBpjAjt
此时易看出 ( p , t ) (p,t) (p,t)元素就是B的第 p p p行向量和A的第 t t t列向量求内积。
 从而我们可以得出结论,这一表达式的意思是,对于BATCH内的每个元素(A‘三维,B’二维),对B在第一维度进行广播(A’‘三维,B’‘三维),最后沿着第二维和第三维计算矩阵相乘B’‘‘A’’‘(A’‘‘二维,B’’'二维)。

而整个的推导过程可以总结为以下几要点

  1. 沿着维数较高输入的第一维开始,判断是否存在于B中,如果在的话就可认为是逐元素操作,暂时忽略该维度;
  2. 如果该维度指示不在B中,则进行广播操作,重新回到1,否则3.
  3. 判断当前最简表达式的意义。
<think>好的,我现在需要详细解析CIOU_loss损失函数,包括它的工作原理、公式推导、实现方法以及应用场景。首先,我得回忆一下CIOU_loss的相关知识,可能结合之前学过的目标检测中的损失函数。 CIOU_loss应该是在IOU损失的基础上改进而来的。之前了解的IOU损失直接使用交集与并集的比例,但它在预测框和真实框没有重叠时梯度消失,而且无法考虑框的长宽比等信息。之后有GIOU,DIOU,然后是CIOU。CIOU_loss可能综合考虑了重叠面积、中心点距离、长宽比等因素。 根据引用中的信息,比如引用[2]提到RIoU和Focal-EIoU等损失函数关注高IoU样本的梯度,但不够简洁和泛化。而CIOU可能属于更早的改进版本,结合了更多的几何因素。引用[4]中提到的梯度调整,可能和CIOU中的某些参数有关,比如alpha或beta参数调整不同质量样本的梯度贡献。 接下来,我需要整理CIOU的公式推导。CIOU应该包括三个部分:IOU损失项、中心点距离的惩罚项,以及长宽比的惩罚项。具体来说,公式可能类似于: CIOU_loss = 1 - IOU + (中心点距离的平方 / 对角线距离的平方) + (长宽比的惩罚项) 其中,长宽比惩罚项可能涉及预测框和真实框的长宽比之差,用某个参数来调节。例如,v可能表示长宽比的一致性,比如v = (4/π²)(arctan(w_true/h_true) - arctan(w_pred/h_pred))²,然后用一个系数乘以v。 然后,公式推导需要详细写出每个部分的意义,以及如何结合起来优化边界框回归。比如,中心点距离的惩罚项帮助预测框更快地靠近真实框的中心,而长宽比的惩罚项让预测框调整比例更接近真实框。 实现方法部分,可能需要用代码展示如何计算CIOU_loss。比如,在PyTorch中,首先计算两个框的IOU,然后计算中心点坐标的欧氏距离,再计算对角线长度,接着计算长宽比的惩罚项v,最后将这些部分组合起来得到最终的损失值。需要注意的是,在计算梯度时,各个部分对梯度的贡献是否平衡,是否有需要调整的参数,比如引用[4]中的beta参数可能影响不同样本的梯度权重。 应用场景方面,CIOU_loss适用于需要精确边界框回归的任务,如目标检测中的Anchor-based方法,尤其是当物体的长宽比变化较大时,CIOU能更好地调整预测框的形状。比如在YOLO系列模型中,可能使用CIOU_loss作为回归损失函数,以提升检测精度。 现在要确认是否正确。例如,CIOU的公式确实包括IOU项、中心点距离项,以及长宽比项。根据记忆,CIOU的公式是: CIOU_Loss = 1 - IOU + (ρ²(b_pred, b_true)/c²) + αv 其中,ρ²是中心点之间的欧氏距离的平方,c是覆盖两个框的最小闭合区域的对角线长度,α是权重系数,v是长宽比一致性的度量。 其中,v的具体计算可能涉及预测框和真实框的宽高比差异。例如,v = (4/π²)(arctan(w_true/h_true) - arctan(w_pred/h_pred))²,而α = v / (1 - IOU + v),这样α作为自适应权重,平衡不同部分的贡献。 在实现时,需要注意计算各个部分的数值稳定性,比如避免除以零,或者处理IOU为零的情况。此外,代码中可能需要计算每个预测框与真实框的这些参数,并对所有样本的损失求平均或其他形式的聚合。 应用场景方面,CIOU_loss通常用于目标检测模型的训练阶段,特别是在需要高精度定位的任务中。例如,YOLOv4、YOLOv5等模型可能采用CIOU_loss作为回归损失的一部分,或者与其他损失函数结合使用,如分类损失和置信度损失结合。 可能还需要比较CIOU_loss与其他损失函数的优劣。比如,相比GIOU,CIOU加入了长宽比的惩罚项,能更全面地调整预测框的形状,因此在某些情况下效果更好,尤其是当物体的宽高比很重要时。但计算复杂度可能稍高,因为需要计算额外的项。 关于公式推导,我需要更仔细地写出每个步骤。例如: 1. 计算IOU:即两个框的交集面积除以并集面积。 2. 计算中心点距离的平方:即 (x_pred - x_true)^2 + (y_pred - y_true)^2。 3. 计算对角线长度c的平方:即覆盖两个框的最小闭合矩形的对角线长度的平方。 4. 计算长宽比的惩罚项v:这里可能需要计算宽高比的差异,例如通过arctan函数将宽高比转换为角度,然后比较预测框和真实框的角度差异,再通过系数调整。 这样,CIOU_loss综合了这三个部分,使得优化过程中不仅考虑重叠面积,还考虑中心点位置和形状相似性,从而更准确地回归边界框。 在实现时,可能需要注意以下几点: - 如何高效地计算IOU,尤其是对于大量边界框的情况。 - 如何处理数值稳定性,比如当c²非常小时,中心点距离项的归一化是否会导致数值问题。 - 如何计算长宽比的惩罚项,避免出现负值或不可导的情况。 最后,应用场景部分需要具体举例,比如在COCO数据集上的目标检测任务,或者特定的应用如自动驾驶中的车辆检测,其中精确的边界框定位是关键。 现在需要将这些思路整理成结构化的回答,包括公式的详细推导、实现方法的代码示例,以及应用场景的具体例子。同时,需要引用提供的参考资料中的相关内容,比如引用[2]和[4]中提到的梯度调整和损失函数设计原则,可能和CIOU中的某些设计有关联。</think>### CIOU_loss 损失函数详细解析 #### **1. 工作机制与公式推导** CIOU(Complete Intersection over Union)损失函数在目标检测中用于优化边界框回归,综合考虑了**重叠区域**、**中心点距离**和**长宽比一致性**三部分因素。其公式为: $$ \mathcal{L}_{\text{CIOU}} = 1 - \text{IOU} + \frac{\rho^2(b_{\text{pred}}, b_{\text{true}})}{c^2} + \alpha v $$ 其中: 1. **$\text{IOU}$**:预测框与真实框的交并比,衡量重叠面积。 2. **$\rho^2$**:预测框与真实框中心点的欧氏距离平方,惩罚中心偏移。 3. **$c^2$**:覆盖两框的最小闭合区域的对角线长度平方,归一化中心点距离。 4. **$v$**:长宽比一致性惩罚项,定义为: $$ v = \frac{4}{\pi^2} \left( \arctan\left(\frac{w_{\text{true}}}{h_{\text{true}}}\right) - \arctan\left(\frac{w_{\text{pred}}}{h_{\text{pred}}}\right) \right)^2 $$ 通过$\arctan$函数将宽高比转换为角度差异,反映形状相似性。 5. **$\alpha$**:自适应权重系数,平衡长宽比项的影响: $$ \alpha = \frac{v}{(1 - \text{IOU}) + v} $$ #### **2. 实现方法(PyTorch示例)** ```python import torch def ciou_loss(pred_boxes, true_boxes): # 计算预测框和真实框的坐标 pred_x1, pred_y1, pred_x2, pred_y2 = pred_boxes.T true_x1, true_y1, true_x2, true_y2 = true_boxes.T # 计算交集面积和IOU inter_x1 = torch.max(pred_x1, true_x1) inter_y1 = torch.max(pred_y1, true_y1) inter_x2 = torch.min(pred_x2, true_x2) inter_y2 = torch.min(pred_y2, true_y2) inter_area = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0) pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1) true_area = (true_x2 - true_x1) * (true_y2 - true_y1) union_area = pred_area + true_area - inter_area iou = inter_area / (union_area + 1e-7) # 计算中心点距离 pred_center_x = (pred_x1 + pred_x2) / 2 pred_center_y = (pred_y1 + pred_y2) / 2 true_center_x = (true_x1 + true_x2) / 2 true_center_y = (true_y1 + true_y2) / 2 center_distance = (pred_center_x - true_center_x)**2 + (pred_center_y - true_center_y)**2 # 计算最小闭合区域对角线长度c² c_x1 = torch.min(pred_x1, true_x1) c_y1 = torch.min(pred_y1, true_y1) c_x2 = torch.max(pred_x2, true_x2) c_y2 = torch.max(pred_y2, true_y2) c_squared = (c_x2 - c_x1)**2 + (c_y2 - c_y1)**2 + 1e-7 # 避免除以零 # 计算长宽比惩罚项v pred_wh = torch.stack([pred_x2 - pred_x1, pred_y2 - pred_y1], dim=1) true_wh = torch.stack([true_x2 - true_x1, true_y2 - true_y1], dim=1) v = (4 / (torch.pi**2)) * (torch.atan(true_wh[:, 0]/true_wh[:, 1]) - torch.atan(pred_wh[:, 0]/pred_wh[:, 1]))**2 # 计算自适应权重α alpha = v / (1 - iou + v + 1e-7) # 组合CIOU损失 ciou = 1 - iou + (center_distance / c_squared) + alpha * v return ciou.mean() ``` #### **3. 应用场景** 1. **高精度目标检测**:如自动驾驶中车辆/行人检测,需精确框定物体位置[^2]。 2. **小目标检测**:长宽比一致性惩罚项能改善小目标的形状回归[^4]。 3. **Anchor-based模型**:YOLOv4、YOLOv5等模型常用CIOU_loss替代MSE损失,提升定位精度[^3]。 #### **4. 对比其他损失函数** | 损失函数 | 优点 | 缺点 | |----------|------|------| | **IOU** | 直接反映重叠面积 | 无重叠时梯度为零,忽略形状 | | **GIOU** | 解决无重叠问题 | 未考虑长宽比 | | **DIOU** | 加速收敛,优化中心距离 | 仍忽略形状差异 | | **CIOU** | 综合形状、位置、重叠 | 计算稍复杂 | #### **5. 梯度特性** CIOU通过$\alpha v$动态调整梯度权重,对高质量样本(高IOU)赋予更高梯度,提升回归精度。例如,当$\beta=0.8$时,能有效抑制低质量样本的干扰。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值