浮点卷积winograd算法

目录

winograd算法简介

知识直通车

winograd代码实例解析


winograd算法简介

现今的Winograd主要来源于1980年,由Shmuel Winograd提出减少FIR滤波器计算量的方法

Shmuel Winograd指出,对于输出个数为m,有r个参数的FIR滤波器,不需要m*r次乘法计算,而仅仅需要:

次乘法计算即可。

下面是一个F(2,3)的例子,即输出m=2个结果,参数r=3个:

可以看到,只需要4次乘法计算结果,发生在m1/m2/m3/m4的计算过程中,外加8次加法。注意,跟g相关的加减乘除可以在初始化时一口气算好,不占计算量。

扩展到二维:

上式中,g为参数,d为数据,G/B/A都是转换矩阵,Y是输出

具体的公式推荐大家直接看论文

对于不同的卷积核和输出大小,winograd的转换矩阵各自不同。本文专注于conv3x3_s1的介绍。

 

知识直通车

知乎文章:https://zhuanlan.zhihu.com/p/72149270

winograd论文:https://arxiv.org/abs/1509.09308

github实例代码(python):https://github.com/Sejudyblues/cs194-winograd/blob/master/winograd.py

 

winograd代码实例解析

#include <stdio.h>
#include <assert.h>
#include <chrono>
using namespace std::chrono;

const float x[8][8] = {       // input X
    { 1.0,     0.0,    0.5,   0.0,   5.0,   0.4,   0.4,   0.8},    // X矩阵
    {-2.0,    -2.0,   -2.0,  -2.0,  -2.0,  -2.0,  -2.0,  -2.0},
    {-2.0,     2.0,    2.1,   3.5,   6.0,   2.0,   4.4,   2.0},
    { 1.0,     1.0,    1.0,   1.0,   1.0,   6.5,   1.4,   1.0},
    { 1.0,    -1.0,   -1.2,  -1.5,  -7.6,  -1.0,  -1.2,  -1.2},
    { 1.0,     1.0,    1.0,   1.7,   1.6,   4.4,   1.6,   1.5},
    { 1.0,    -1.0,   -3.5,  -6.0,  -1.3,  -1.0,  -1.6,  -7.0},
    { 0.0,     0.0,    0.0,   0.6,   0.2,   9.0,   0.4,   0.3}
};

const float weight[9] = {         // Weight w
    1.2, 0.63, 0.4,
    0.2,  0.5, 0.56,
    1.5,  0.3, 0.74
};

//U=GgG^T
const float ktm[8][3] = {       // Acutually Matrix G
    {   1.0f,        0.0f,        0.0f},    // G矩阵
    {-2.0f/9,     -2.0f/9,     -2.0f/9},
    {-2.0f/9,      2.0f/9,     -2.0f/9},
    {1.0f/90,     1.0f/45,     2.0f/45},
    {1.0f/90,    -1.0f/45,     2.0f/45},
    {1.0f/45,     1.0f/90,    1.0f/180},
    {1.0f/45,    -1.0f/90,    1.0f/180},
    {   0.0f,        0.0f,        1.0f}
};

int main()
{
    //w^T
    const float* k0 = weight;
    const float* k1 = weight+3;
    const float* k2 = weight+6;

    //U^T = G.g^T.G^T

    float tmp[8][3];    // tmp = G.g^T
    for(int i=0; i<8; i++)
    {
        tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
        tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
        tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
    }
    
    //U^T = tmp.G^T
    float U[64];
    for (int j=0; j<8; j++)
    {
        float* tmpp = &tmp[j][0];

        for (int i=0; i<8; i++)
        {
            U[j*8 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
        }
    }
//         const float itm[8][8] = {
//             {1.0f,  0.0f, -5.25f,  0.00f,  5.25f,  0.00f, -1.0f, 0.0f},
//
//             {0.0f,  1.0f,  1.00f, -4.25f, -4.25f,  1.00f,  1.0f, 0.0f},
//             {0.0f, -1.0f,  1.00f,  4.25f, -4.25f, -1.00f,  1.0f, 0.0f},
//
//             {0.0f,  0.5f,  0.25f, -2.50f, -1.25f,  2.00f,  1.0f, 0.0f},
//             {0.0f, -0.5f,  0.25f,  2.50f, -1.25f, -2.00f,  1.0f, 0.0f},
//
//             {0.0f,  2.0f,  4.00f, -2.50f, -5.00f,  0.50f,  1.0f, 0.0f},
//             {0.0f, -2.0f,  4.00f,  2.50f, -5.00f, -0.50f,  1.0f, 0.0f},
//
//             {0.0f, -1.0f,  0.00f,  5.25f,  0.00f, -5.25f,  0.0f, 1.0f}
//         };

        // 0 = r00 - r06 + (r04 - r02) * 5.25
        // 7 = r07 - r01 + (r03 - r05) * 5.25

        // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05)
        // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05)

        // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2)
        // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2)

        // reuse r04 * 1.25
        // reuse r03 * 2.5
        // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
        // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)

    //V^T = B^T.d^T.B
    float tempv[8][8];
    for (int m=0; m<8; m++)
    {
        const float *r0 = &x[m][0];

        tempv[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25f;
        tempv[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25f;

        float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25f);
        float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25f);

        tempv[1][m] = tmp12a + tmp12b;
        tempv[2][m] = tmp12a - tmp12b;

        float tmp34a = (r0[6] + r0[2] * 0.25f - r0[4] * 1.25f);
        float tmp34b = (r0[1] * 0.5f - r0[3] * 2.5f + r0[5] * 2.f);

        tempv[3][m] = tmp34a + tmp34b;
        tempv[4][m] = tmp34a - tmp34b;

        float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25f) * 4.f);
        float tmp56b = (r0[1] * 2.f - r0[3] * 2.5f + r0[5] * 0.5f);

        tempv[5][m] = tmp56a + tmp56b;
        tempv[6][m] = tmp56a - tmp56b;
    }

    float V[64];
    for (int m=0; m<8; m++)
    {
        const float* tmp0 = tempv[m];

        V[m*8] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25f;
        V[m*8+7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25f;

        float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25f);
        float tmp12b = (tmp0[1] - tmp0[3] * 4.25f + tmp0[5]);

        V[m*8+1] = tmp12a + tmp12b;
        V[m*8+2] = tmp12a - tmp12b;

        float tmp34a = (tmp0[6] + tmp0[2] * 0.25f - tmp0[4] * 1.25f);
        float tmp34b = (tmp0[1] * 0.5f - tmp0[3] * 2.5f + tmp0[5] * 2.f);

        V[m*8+3] = tmp34a + tmp34b;
        V[m*8+4] = tmp34a - tmp34b;

        float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25f) * 4.f);
        float tmp56b = (tmp0[1] * 2.f - tmp0[3] * 2.5f + tmp0[5] * 0.5f);

        V[m*8+5] = tmp56a + tmp56b;
        V[m*8+6] = tmp56a - tmp56b;
    }

    //Matrix M^T=U^T.V^T
    float tempResult[64];
    for(int n=0;n<64;n++)
    {
        tempResult[n] = V[n]*U[n]; 
    }

//         const float otm[6][8] = {
//             {1.0f,  1.0f,   1.0f,   1.0f,   1.0f,  32.0f, 32.0f, 0.0f},
//             {0.0f,  1.0f,  -1.0f,   2.0f,  -2.0f,  16.0f,-16.0f, 0.0f},
//             {0.0f,  1.0f,   1.0f,   4.0f,   4.0f,   8.0f,  8.0f, 0.0f},
//             {0.0f,  1.0f,  -1.0f,   8.0f,  -8.0f,   4.0f, -4.0f, 0.0f},
//             {0.0f,  1.0f,   1.0f,  16.0f,  16.0f,   2.0f,  2.0f, 0.0f},
//             {0.0f,  1.0f,  -1.0f,  32.0f, -32.0f,   1.0f, -1.0f, 1.0f}
//         };

        // 0 = r0 + (r1 + r2) + (r3 + r4)     + (r5 + r6) * 32
        // 1 =      (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
        // 2 =      (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
        // 3 =      (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
        // 4 =      (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2
        // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6)

    //R = A^T.M.A
    float* pTempR = tempResult;
    float r[6][8];
    for (int m=0; m<8; m++)
    {
        float tmp024a = pTempR[1] + pTempR[2];
        float tmp135a = pTempR[1] - pTempR[2];

        float tmp024b = pTempR[3] + pTempR[4];
        float tmp135b = pTempR[3] - pTempR[4];

        float tmp024c = pTempR[5] + pTempR[6];
        float tmp135c = pTempR[5] - pTempR[6];

        r[0][m] = pTempR[0] + tmp024a + tmp024b + tmp024c * 32;
        r[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8;
        r[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c;

        r[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
        r[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4;
        r[5][m] = pTempR[7] + tmp135a + tmp135b * 32 + tmp135c;

        pTempR+=8;
    }

    float result[64]; 
    float* pr = result;
    for (int m=0; m<6; m++)
    {
        const float* tmp0 = r[m];

        float tmp024a = tmp0[1] + tmp0[2];
        float tmp135a = tmp0[1] - tmp0[2];

        float tmp024b = tmp0[3] + tmp0[4];
        float tmp135b = tmp0[3] - tmp0[4];

        float tmp024c = tmp0[5] + tmp0[6];
        float tmp135c = tmp0[5] - tmp0[6];

        pr[0+6*m] = tmp0[0] + tmp024a + tmp024b + tmp024c * 32;
        pr[2+6*m] = tmp024a + tmp024b * 4 + tmp024c * 8;
        pr[4+6*m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c;

        pr[1+6*m] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
        pr[3+6*m] = tmp135a + tmp135b * 8 + tmp135c * 4;
        pr[5+6*m] = tmp0[7] + tmp135a + tmp135b * 32 + tmp135c;
    }

    float output[36];
    int stride_h = 1;
    int stride_w = 1;
    for(int h=0;h<6;h++)
    {
        int src_h = stride_h*h;
        for(int w=0;w<6;w++)
        {
            float sum = 0;
            int src_w = stride_w*w;
            for(int k_h=0;k_h<3;k_h++)
            {
                for(int k_w=0;k_w<3;k_w++)
                {
                    sum+=x[k_h+src_h][k_w+src_w]*weight[3*k_h+k_w];
                }
            }
            output[h*6+w] = sum;
        }
    }

    for(int i=0;i<36;i++)
        printf("%f==%f \n",output[i],result[i]);

    return 0;
}

 

评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值