c++的矩阵乘法加速trick

最近读RNNLM的源代码,发现其实现矩阵乘法时使用了一个trick,这里描述一下这个trick。

首先是正常版的矩阵乘法(其实是矩阵乘向量)

void matrixXvector(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){
    for(int row=0;row<srcmatrix_rownum;++row){
        destvect[row]=0;
        for(int col=0;col<srcmatrix_colnum;++col){
            destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];
        }
    }
}

就是最简单的for循环,逐行逐列遍历。

接下来是RNNLM中实现的trick版本

void matrixXvector2(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){
    int row, col;
    float val1, val2, val3, val4;
    float val5, val6, val7, val8;
    
    for(row=0;row<srcmatrix_rownum/8;++row){
        val1 = 0;
        val2 = 0;
        val3 = 0;
        val4 = 0;
        val5 = 0;
        val6 = 0;
        val7 = 0;
        val8 = 0;
        
        for(col=0;col<srcmatrix_colnum;++col){
            val1+=srcmatrix[(row*8+0)*srcmatrix_colnum+col]*srcvect[col];
            val2+=srcmatrix[(row*8+1)*srcmatrix_colnum+col]*srcvect[col];
            val3+=srcmatrix[(row*8+2)*srcmatrix_colnum+col]*srcvect[col];
            val4+=srcmatrix[(row*8+3)*srcmatrix_colnum+col]*srcvect[col];
            val5+=srcmatrix[(row*8+4)*srcmatrix_colnum+col]*srcvect[col];
            val6+=srcmatrix[(row*8+5)*srcmatrix_colnum+col]*srcvect[col];
            val7+=srcmatrix[(row*8+6)*srcmatrix_colnum+col]*srcvect[col];
            val8+=srcmatrix[(row*8+7)*srcmatrix_colnum+col]*srcvect[col];
        }
        
        destvect[row*8+0]+=val1;
        destvect[row*8+1]+=val2;
        destvect[row*8+2]+=val3;
        destvect[row*8+3]+=val4;
        destvect[row*8+4]+=val5;
        destvect[row*8+5]+=val6;
        destvect[row*8+6]+=val7;
        destvect[row*8+7]+=val8;
        
    }
    
    for(row=row*8;row<srcmatrix_rownum;++row){
        for(col=0;col<srcmatrix_colnum;++col){
            destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];    
        }
    }
}

对比普通版,trick版把遍历行的for循环分成了8份,同时进行列遍历。

实际测试中,这个trick版比普通版快了接近2倍~这是编译器优化造成的么……?

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值