- 正文前感谢昇腾各位工作人员,没有你们的辛勤就没有我们的进步
- 本文立意交流大赛FastGeluGrad算子编译过程
- 这道题是在FastGelu基础上的升级题目,嗯,不难,公式很吓人
![cke_2160.png](https://img-blog.csdnimg.cn/img_convert/7d733f5573d6b934f2c40a084b433cf3.png)
- 算子要求fp16,fp32 ,计算类型符合所有的API公式限制,也就是说不需要类型转换
- 比FastGelu多了一个输入dy,两个输入,一个输出,并不需要大幅度的数据搬迁
- 唯一需要解决的就是compute的算法设计
__aicore__ inline void Compute(int32_t progress) {
LocalTensor<DTYPE_X> inLocal = inQueueIN.DeQue<DTYPE_X>();
LocalTensor<DTYPE_DY> dyLocal = inLocal[0];
LocalTensor<DTYPE_X> xLocal = inLocal[this->tileLength];
LocalTensor<DTYPE_Z> outLocal = outQueueOUT.AllocTensor<DTYPE_Z>();
LocalTensor<DTYPE_Z> tempTensor1 = calcBuf.Get<DTYPE_Z>();
LocalTensor<DTYPE_Z> tempTensor2 = calcBuf1.Get<DTYPE_Z>();
Abs(tempTensor1, xLocal, this->tileLength);
Muls(tempTensor1, tempTensor1, (DTYPE_Z)(-1.702), this->tileLength);
Exp(tempTensor1, tempTensor1, this->tileLength);//exp(-1.702abs(x))
Adds(outLocal, tempTensor1, (DTYPE_Z)(1), this->tileLength);
Mul(outLocal, outLocal, outLocal, this->tileLength);//分母
Abs(tempTensor2, xLocal, this->tileLength);
Sub(tempTensor2, xLocal, tempTensor2, this->tileLength);//x-abs(x)
Muls(tempTensor2, tempTensor2, (DTYPE_Z)(1.702), this->tileLength);//1.702(x-abs(x))
Exp(tempTensor2, tempTensor2, this->tileLength);//exp(1.702(x-abs(x)))
Add(tempTensor2, tempTensor2, tempTensor1,this->tileLength);//exp(-1.702abs(x)) + exp(1.702(x-abs(x)))
Mul(tempTensor1, tempTensor1, xLocal, this->tileLength);//xexp(-1.702abs(x))
Muls(tempTensor1, tempTensor1, (DTYPE_Z)(1.702), this->tileLength);//1.702xexp(-1.702abs(x))
Add(tempTensor2, tempTensor2, tempTensor1,this->tileLength);//exp(-1.702abs(x)) + exp(1.702(x-abs(x))) + 1.702xexp(-1.702abs(x))
Mul(tempTensor2, dyLocal, tempTensor2, this->tileLength);
Div(outLocal, tempTensor2, outLocal, this->tileLength);
outQueueOUT.EnQue<DTYPE_Z>(outLocal);
inQueueIN.FreeTensor(inLocal);
}
复制
![cke_23916.png](https://img-blog.csdnimg.cn/img_convert/57696a3bc31c124e0a03713549d9cb12.png)
![cke_28815.png](https://img-blog.csdnimg.cn/img_convert/2ab3c0fc44ddff968199036da886471a.png)
![cke_38187.png](https://img-blog.csdnimg.cn/img_convert/a9acb01e287d52890babd3b7b5bc0ef6.png)
![cke_46160.png](https://img-blog.csdnimg.cn/img_convert/0c36cf82718c612256b2a1b8bd12a015.png)
![cke_53388.png](https://img-blog.csdnimg.cn/img_convert/84bc0e218134c51a2647d206364aec96.png)
- 在测试案例中会遇到RuntimeWarning: invalid value encountered in true_divide
- 这个Python警告通常在进行浮点数除法时出现。当被除数或除数的值为无效的特殊浮点数(如NaN或inf)时,就会发出该警告
- 或者是某些情况下,数据的值可能超出了浮点数的表示范围,导致除法运算产生无效值。
- 这里直接忽略了这个warnings