目录
winograd算法简介
现今使用的Winograd快速卷积算法,源自Shmuel Winograd于1980年提出的一种减少FIR滤波器计算量的方法。
Shmuel Winograd指出,对于输出个数为m、参数(抽头)个数为r的FIR滤波器(记为F(m, r)),不需要 $m \times r$ 次乘法,而仅仅需要

$$\mu\big(F(m,r)\big) = m + r - 1$$

次乘法计算即可。
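下面是一个F(2,3)的例子,即输出m=2个结果,参数r=3个:

$$F(2,3)=\begin{bmatrix} d_0 & d_1 & d_2 \\ d_1 & d_2 & d_3 \end{bmatrix}\begin{bmatrix} g_0 \\ g_1 \\ g_2 \end{bmatrix}=\begin{bmatrix} m_1+m_2+m_3 \\ m_2-m_3-m_4 \end{bmatrix}$$

其中

$$m_1=(d_0-d_2)\,g_0,\qquad m_2=(d_1+d_2)\,\frac{g_0+g_1+g_2}{2},$$
$$m_3=(d_2-d_1)\,\frac{g_0-g_1+g_2}{2},\qquad m_4=(d_1-d_3)\,g_2$$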
可以看到,计算结果只需要4次乘法(分别发生在m1、m2、m3、m4的计算中),外加8次加法:对数据d的4次加减,以及汇总结果的4次加减。注意,只与g相关的加、减、乘、除可以在初始化时一次性预先算好,不计入推理阶段的计算量。
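扩展到二维:

$$Y = A^{T}\left[\left(G\,g\,G^{T}\right)\odot\left(B^{T}\,d\,B\right)\right]A$$

上式中,g为r×r的卷积核参数,d为(m+r-1)×(m+r-1)的输入数据块,G/B/A都是转换矩阵,⊙表示逐元素相乘,Y是m×m的输出。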
转换矩阵的具体推导,推荐大家直接阅读下方列出的论文。
对于不同的卷积核和输出大小,winograd的转换矩阵各不相同。本文专注于conv3x3_s1(3x3卷积、步长1)的介绍;下面的代码实现的是F(6,3),即由一个8×8的输入块计算出6×6的输出。
知识直通车
知乎文章:https://zhuanlan.zhihu.com/p/72149270
winograd论文:https://arxiv.org/abs/1509.09308
github实例代码(python):https://github.com/Sejudyblues/cs194-winograd/blob/master/winograd.py
winograd代码实例解析
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <chrono>
using namespace std::chrono;
// 8x8 input tile d (row-major: x[row][col]). main() convolves it with the
// 3x3 `weight` at stride 1, which yields a 6x6 output -- exactly one
// Winograd F(6,3) tile.
const float x[8][8] = { // input tile d
{ 1.0, 0.0, 0.5, 0.0, 5.0, 0.4, 0.4, 0.8}, // X matrix
{-2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0},
{-2.0, 2.0, 2.1, 3.5, 6.0, 2.0, 4.4, 2.0},
{ 1.0, 1.0, 1.0, 1.0, 1.0, 6.5, 1.4, 1.0},
{ 1.0, -1.0, -1.2, -1.5, -7.6, -1.0, -1.2, -1.2},
{ 1.0, 1.0, 1.0, 1.7, 1.6, 4.4, 1.6, 1.5},
{ 1.0, -1.0, -3.5, -6.0, -1.3, -1.0, -1.6, -7.0},
{ 0.0, 0.0, 0.0, 0.6, 0.2, 9.0, 0.4, 0.3}
};
// 3x3 convolution kernel g, stored row-major: weight[3*row + col].
const float weight[9] = { // kernel g
1.2, 0.63, 0.4,
0.2, 0.5, 0.56,
1.5, 0.3, 0.74
};
// Kernel-transform matrix G (8x3) for Winograd F(6,3).
// The transformed kernel is U = G g G^T, where g is the 3x3 `weight`.
const float ktm[8][3] = { // Actually matrix G
{ 1.0f, 0.0f, 0.0f}, // G matrix
{-2.0f/9, -2.0f/9, -2.0f/9},
{-2.0f/9, 2.0f/9, -2.0f/9},
{1.0f/90, 1.0f/45, 2.0f/45},
{1.0f/90, -1.0f/45, 2.0f/45},
{1.0f/45, 1.0f/90, 1.0f/180},
{1.0f/45, -1.0f/90, 1.0f/180},
{ 0.0f, 0.0f, 1.0f}
};
// ---------------------------------------------------------------------------
// Winograd F(6,3) demo.
//
// Computes the 6x6 output of a 3x3, stride-1 convolution (cross-correlation,
// as used in CNNs) of the 8x8 tile `x` with kernel `weight` via
//     Y = A^T [ (G g G^T) (.) (B^T d B) ] A
// and checks it against a direct sliding-window implementation.
//
// Convention: the 8x8 transform-domain matrices (Ut, Vt, Mt) are stored
// TRANSPOSED, as flat row-major float[64]. The element-wise product commutes
// with transposition, and the output transform undoes it, so `result` ends
// up holding Y in plain row-major order.
// ---------------------------------------------------------------------------

// Kernel transform for F(6,3).
//   g  : 3x3 kernel, row-major (g[3*row + col]).
//   Ut : receives (G g G^T)^T as a row-major 8x8 matrix, with G = `ktm`.
// Depends only on the weights, so a real layer would compute it once.
static void winograd_f63_transform_kernel(const float g[9], float Ut[64])
{
    const float* k0 = g;      // kernel row 0
    const float* k1 = g + 3;  // kernel row 1
    const float* k2 = g + 6;  // kernel row 2

    // tmp = G . g^T  (8x3)
    float tmp[8][3];
    for (int i = 0; i < 8; i++)
    {
        tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
        tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
        tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
    }

    // Ut = tmp . G^T = G g^T G^T = (G g G^T)^T  (8x8)
    for (int j = 0; j < 8; j++)
    {
        const float* row = tmp[j];
        for (int i = 0; i < 8; i++)
            Ut[j * 8 + i] = row[0] * ktm[i][0] + row[1] * ktm[i][1] + row[2] * ktm[i][2];
    }
}

// 1-D F(6,3) input transform: out[k*stride] = sum_j B^T[k][j] * r[j], k = 0..7.
//
// B^T (rows k, columns j):
//   {1,  0,    -5.25,  0,     5.25,  0,     -1, 0}
//   {0,  1,     1,    -4.25, -4.25,  1,      1, 0}
//   {0, -1,     1,     4.25, -4.25, -1,      1, 0}
//   {0,  0.5,   0.25, -2.5,  -1.25,  2,      1, 0}
//   {0, -0.5,   0.25,  2.5,  -1.25, -2,      1, 0}
//   {0,  2,     4,    -2.5,  -5,     0.5,    1, 0}
//   {0, -2,     4,     2.5,  -5,    -0.5,    1, 0}
//   {0, -1,     0,     5.25,  0,    -5.25,   0, 1}
// Rows 1..6 come in +/- pairs that differ only in the sign of the odd-index
// terms, so each pair shares an even part (a) and an odd part (b).
static void winograd_f63_input_1d(const float r[8], float* out, int stride)
{
    const float tmp12a = r[2] + r[6] - r[4] * 4.25f;
    const float tmp12b = r[1] + r[5] - r[3] * 4.25f;
    const float tmp34a = r[6] + r[2] * 0.25f - r[4] * 1.25f;
    const float tmp34b = r[1] * 0.5f - r[3] * 2.5f + r[5] * 2.f;
    const float tmp56a = r[6] + (r[2] - r[4] * 1.25f) * 4.f;
    const float tmp56b = r[1] * 2.f - r[3] * 2.5f + r[5] * 0.5f;

    out[0 * stride] = r[0] - r[6] + (r[4] - r[2]) * 5.25f;
    out[1 * stride] = tmp12a + tmp12b;
    out[2 * stride] = tmp12a - tmp12b;
    out[3 * stride] = tmp34a + tmp34b;
    out[4 * stride] = tmp34a - tmp34b;
    out[5 * stride] = tmp56a + tmp56b;
    out[6 * stride] = tmp56a - tmp56b;
    out[7 * stride] = r[7] - r[1] + (r[3] - r[5]) * 5.25f;
}

// 1-D F(6,3) output transform: out[k*stride] = sum_j A^T[k][j] * r[j], k = 0..5.
//
// A^T (rows k, columns j):
//   {1, 1,  1,  1,   1,  32,  32, 0}
//   {0, 1, -1,  2,  -2,  16, -16, 0}
//   {0, 1,  1,  4,   4,   8,   8, 0}
//   {0, 1, -1,  8,  -8,   4,  -4, 0}
//   {0, 1,  1, 16,  16,   2,   2, 0}
//   {0, 1, -1, 32, -32,   1,  -1, 1}
// Even rows use the pairwise sums r1+r2, r3+r4, r5+r6, and odd rows use the
// pairwise differences, so each is computed once.
static void winograd_f63_output_1d(const float r[8], float* out, int stride)
{
    const float tmp024a = r[1] + r[2];
    const float tmp135a = r[1] - r[2];
    const float tmp024b = r[3] + r[4];
    const float tmp135b = r[3] - r[4];
    const float tmp024c = r[5] + r[6];
    const float tmp135c = r[5] - r[6];

    out[0 * stride] = r[0] + tmp024a + tmp024b + tmp024c * 32;
    out[1 * stride] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
    out[2 * stride] = tmp024a + tmp024b * 4 + tmp024c * 8;
    out[3 * stride] = tmp135a + tmp135b * 8 + tmp135c * 4;
    out[4 * stride] = tmp024a + tmp024b * 16 + tmp024c + tmp024c;
    out[5 * stride] = r[7] + tmp135a + tmp135b * 32 + tmp135c;
}

// Reference: direct 3x3, stride-1 "valid" cross-correlation of the 8x8 tile
// `in` with the row-major kernel `g`; writes the 6x6 result row-major.
static void direct_conv3x3_s1(const float in[8][8], const float g[9], float out[36])
{
    for (int h = 0; h < 6; h++)
    {
        for (int w = 0; w < 6; w++)
        {
            float sum = 0;
            for (int kh = 0; kh < 3; kh++)
                for (int kw = 0; kw < 3; kw++)
                    sum += in[h + kh][w + kw] * g[3 * kh + kw];
            out[h * 6 + w] = sum;
        }
    }
}

int main()
{
    // 1) Kernel transform: Ut = (G g G^T)^T.
    float Ut[64];
    winograd_f63_transform_kernel(weight, Ut);

    // 2) Input transform, applied separably: Vt = (B^T d B)^T.
    //    Pass 1, per input row m: tmpv[k*8 + m] = (B^T d^T)[k][m].
    float tmpv[64];
    for (int m = 0; m < 8; m++)
        winograd_f63_input_1d(x[m], &tmpv[m], 8);
    //    Pass 2, per row of tmpv: Vt = tmpv . B = B^T d^T B.
    float Vt[64];
    for (int m = 0; m < 8; m++)
        winograd_f63_input_1d(&tmpv[m * 8], &Vt[m * 8], 1);

    // 3) Element-wise product in the transform domain: Mt = Ut (.) Vt = M^T.
    float Mt[64];
    for (int n = 0; n < 64; n++)
        Mt[n] = Vt[n] * Ut[n];

    // 4) Output transform: Y = A^T M A.
    //    Pass 1, per row m of Mt (= column m of M): r[k*8 + m] = (A^T M)[k][m].
    float r[6 * 8];
    for (int m = 0; m < 8; m++)
        winograd_f63_output_1d(&Mt[m * 8], &r[m], 8);
    //    Pass 2, per row of A^T M: result = (A^T M) A, 6x6 row-major.
    float result[36];
    for (int m = 0; m < 6; m++)
        winograd_f63_output_1d(&r[m * 8], &result[m * 6], 1);

    // 5) Reference result via direct convolution.
    float output[36];
    direct_conv3x3_s1(x, weight, output);

    // 6) Compare. F(6,3) in single precision loses some accuracy to the large
    //    transform coefficients (5.25 in B^T, 32 in A^T), so use a mixed
    //    absolute/relative tolerance instead of exact equality.
    const float kTolerance = 1e-3f;
    int mismatches = 0;
    float max_abs_err = 0.0f;
    for (int i = 0; i < 36; i++)
    {
        printf("%f==%f \n", output[i], result[i]);
        const float err = fabsf(output[i] - result[i]);
        if (err > max_abs_err)
            max_abs_err = err;
        // Written as !(err <= tol) so that a NaN result counts as a mismatch.
        if (!(err <= kTolerance * (1.0f + fabsf(output[i]))))
            mismatches++;
    }
    printf("max abs error: %g, mismatches: %d/36\n", max_abs_err, mismatches);

    // A non-zero exit status lets scripts/CI detect a broken transform.
    return mismatches == 0 ? 0 : 1;
}