利用tiling以及AVX-512指令集进行优化
对于矩阵乘法的优化，最重要的两个目标是提高并行性和提高局部性。并行性可以通过多线程以及SIMD指令集来提高；而对于局部性，tiling则是最重要的优化手段之一。
tiling的目的:确保一个 Cache Line 在被用过以后,后面再用的时候其仍然在 cache 中,没有被 evict
实现思路:当一个数组总的数据量无法 fit in cache时,把总数据分成一个个 tile 去访问,令每个 tile 都可以 fit in the Cache
具体做法:把一层内层循环分成 outer loop * inner loop。然后把 outer loop 移到更外层去,从而确保 inner loop 一定能 fit in cache
tiling实现矩阵乘法
/* Tile edge length. An externally defined B macro is honoured when present;
 * otherwise a default of 32 is used. The result does not depend on this value,
 * only the cache behaviour does. */
#ifndef MMM_TILE
#ifdef B
#define MMM_TILE (B)
#else
#define MMM_TILE 32
#endif
#endif

/*
 * Tiled (blocked) integer matrix multiply: c += a * b.
 *
 * a, b, c  row-major n x n matrices; element (r, col) lives at index r * n + col.
 * n        matrix dimension; any non-negative value is accepted, it does not
 *          have to be a multiple of the tile size.
 *
 * The product is ACCUMULATED into c, so the caller must zero c beforehand to
 * obtain a plain product.
 */
void mmm(int* a, int* b, int* c, int n) {
    int i, j, k, i1, j1, k1;
    for (i = 0; i < n; i += MMM_TILE) {
        /* Clamp every tile to the matrix edge so a partial last tile is safe. */
        int i_end = (n - i < MMM_TILE) ? n : i + MMM_TILE;
        for (j = 0; j < n; j += MMM_TILE) {
            int j_end = (n - j < MMM_TILE) ? n : j + MMM_TILE;
            for (k = 0; k < n; k += MMM_TILE) {
                int k_end = (n - k < MMM_TILE) ? n : k + MMM_TILE;
                /* tile x tile mini matrix multiplication */
                for (i1 = i; i1 < i_end; i1++) {
                    /* Scalar code must visit every column: step by 1, not by
                     * a SIMD width of 16. */
                    for (j1 = j; j1 < j_end; j1++) {
                        int sum = c[i1 * n + j1];
                        for (k1 = k; k1 < k_end; k1++) {
                            sum += a[i1 * n + k1] * b[k1 * n + j1];
                        }
                        c[i1 * n + j1] = sum;
                    }
                }
            }
        }
    }
}
采用AVX512指令集优化后矩阵乘法代码
/*
 * Worker thread entry point: computes a band of rows of the product matrix.
 *
 * IpParam points to a MYDATA with:
 *   A      left operand, row-major N x N
 *   B      TRANSPOSE of the right operand, row-major N x N, so that both
 *          operands of each dot product are contiguous rows
 *   C      output, row-major N x N
 *   begin  first row of C this thread computes
 *   end    one past the last row; clamped to N here, so a caller passing a
 *          larger sentinel for the last thread is still safe
 *   block  tile edge used to walk the band; must be positive
 *
 * Each thread writes only rows [begin, end) of C, so concurrent workers never
 * touch the same element and no locking is needed. Every element is computed
 * in full by one thread and stored with '=', so C needs no prior zeroing.
 *
 * Always returns 0.
 */
DWORD WINAPI ThreadProc(LPVOID IpParam)
{
    MYDATA* pmd = (MYDATA*)IpParam;
    const int* A = pmd->A;
    const int* Bt = pmd->B;
    int* C = pmd->C;
    const int block = pmd->block;
    const int row_begin = pmd->begin;
    const int row_end = pmd->end < N ? pmd->end : N;

    // A non-positive tile size would make the tile loops below never advance.
    if (block <= 0)
    {
        return 0;
    }

    // Largest multiple of 16 not exceeding N: the part handled 16 lanes at a time.
    const int k_vec_end = N - N % 16;

    for (int i = row_begin; i < row_end; i += block)
    {
        const int i_stop = (row_end - i < block) ? row_end : i + block;
        for (int j = 0; j < N; j += block)
        {
            const int j_stop = (N - j < block) ? N : j + block;
            /* block x block mini matrix multiplication */
            for (int i1 = i; i1 < i_stop; i1++)
            {
                const int* a_row = A + (size_t)i1 * N;
                for (int j1 = j; j1 < j_stop; j1++)
                {
                    const int* b_row = Bt + (size_t)j1 * N;
                    __m512i res_vec = _mm512_setzero_si512();
                    int k1;
                    for (k1 = 0; k1 < k_vec_end; k1 += 16)
                    {
                        __m512i m1_vec = _mm512_loadu_si512((const void*)(a_row + k1));
                        __m512i m2_vec = _mm512_loadu_si512((const void*)(b_row + k1));
                        res_vec = _mm512_add_epi32(res_vec, _mm512_mullo_epi32(m1_vec, m2_vec));
                    }
                    // Horizontal sum of the 16 lanes, without aliasing the vector as int*.
                    int sum = _mm512_reduce_add_epi32(res_vec);
                    // Scalar tail for the case where N is not a multiple of 16.
                    for (; k1 < N; k1++)
                    {
                        sum += a_row[k1] * b_row[k1];
                    }
                    C[(size_t)i1 * N + j1] = sum;
                }
            }
        }
    }
    return 0;
}
运行速度
矩阵规模 | 1线程 | 2线程 | 4线程 | 8线程 | 16线程 | 32线程 | 64线程 |
---|---|---|---|---|---|---|---|
1024*1024 | 0.448s | 0.254s | 0.207s | 0.139s | 0.103s | 0.13s | 0.19s |
2048*2048 | 4.008s | 2.023s | 1.203s | 0.901s | 0.707s | 0.74s | 0.909s |
4096*4096 | 33.612s | 19.876s | 9.974s | 6.6371s | 4.866s | 5.257s | 6.139s |
通过多线程以及AVX-512指令集和tiling对程序优化后,对比python基准程序同样可实现上千倍优化。
完整代码
#include <windows.h>
#include <iostream>
#include <ctime>
#include <xmmintrin.h>
#include "immintrin.h"
#define N 4096
#define MAX_THREADS 7
using namespace std;
// Argument bundle handed to each worker thread through CreateThread's
// single pointer parameter.
struct MYDATA
{
    int begin;  // start index of this thread's slice (main sets it to i * block)
    int end;    // end index of this thread's slice, as assigned by main
    int block;  // slice width; main sets it to N / thread count
    int* A;     // left operand matrix
    int* B;     // right operand matrix (main passes the transposed copy revB)
    int* C;     // result matrix
};
/*
 * Worker thread entry point: computes a band of rows of the product matrix.
 *
 * IpParam points to a MYDATA with:
 *   A      left operand, row-major N x N
 *   B      TRANSPOSE of the right operand, row-major N x N, so that both
 *          operands of each dot product are contiguous rows
 *   C      output, row-major N x N
 *   begin  first row of C this thread computes
 *   end    one past the last row; clamped to N here, so a caller passing a
 *          larger sentinel for the last thread is still safe
 *   block  tile edge used to walk the band; must be positive
 *
 * Each thread writes only rows [begin, end) of C, so concurrent workers never
 * touch the same element and no locking is needed. Every element is computed
 * in full by one thread and stored with '=', so C needs no prior zeroing.
 *
 * Always returns 0.
 */
DWORD WINAPI ThreadProc(LPVOID IpParam)
{
    MYDATA* pmd = (MYDATA*)IpParam;
    const int* A = pmd->A;
    const int* Bt = pmd->B;
    int* C = pmd->C;
    const int block = pmd->block;
    const int row_begin = pmd->begin;
    const int row_end = pmd->end < N ? pmd->end : N;

    // A non-positive tile size would make the tile loops below never advance.
    if (block <= 0)
    {
        return 0;
    }

    // Largest multiple of 16 not exceeding N: the part handled 16 lanes at a time.
    const int k_vec_end = N - N % 16;

    for (int i = row_begin; i < row_end; i += block)
    {
        const int i_stop = (row_end - i < block) ? row_end : i + block;
        for (int j = 0; j < N; j += block)
        {
            const int j_stop = (N - j < block) ? N : j + block;
            /* block x block mini matrix multiplication */
            for (int i1 = i; i1 < i_stop; i1++)
            {
                const int* a_row = A + (size_t)i1 * N;
                for (int j1 = j; j1 < j_stop; j1++)
                {
                    const int* b_row = Bt + (size_t)j1 * N;
                    __m512i res_vec = _mm512_setzero_si512();
                    int k1;
                    for (k1 = 0; k1 < k_vec_end; k1 += 16)
                    {
                        __m512i m1_vec = _mm512_loadu_si512((const void*)(a_row + k1));
                        __m512i m2_vec = _mm512_loadu_si512((const void*)(b_row + k1));
                        res_vec = _mm512_add_epi32(res_vec, _mm512_mullo_epi32(m1_vec, m2_vec));
                    }
                    // Horizontal sum of the 16 lanes, without aliasing the vector as int*.
                    int sum = _mm512_reduce_add_epi32(res_vec);
                    // Scalar tail for the case where N is not a multiple of 16.
                    for (; k1 < N; k1++)
                    {
                        sum += a_row[k1] * b_row[k1];
                    }
                    C[(size_t)i1 * N + j1] = sum;
                }
            }
        }
    }
    return 0;
}
// Operand matrices, row-major N x N, allocated on the heap at start-up.
int *A = new int[N * N];
int *B = new int[N * N];

// Result matrix; static storage, so it starts out zero-filled.
int C[N * N];

// Transposed copy of B (filled in by main).
int *revB = new int[N * N];

// One slot per benchmark configuration.
double rate[MAX_THREADS];

// Thread counts to benchmark, one entry per configuration.
int threads[MAX_THREADS] = {1, 2, 4, 8, 16, 32, 64};
int main()
{
for (int i = 0; i < N * N; i++)
{
A[i] = rand() % 2;
B[i] = rand() % 2;
}
for (int index = 0; index < N * N; index++)
{
int i = index / N, j = index % N;
revB[i * N + j] = B[j * N + i];
}
for (int i = 0; i < MAX_THREADS; i++)
{
int m = threads[i];
clock_t start, end;
start = clock();
HANDLE hThread[N];//初始化临界区
static MYDATA mydt[N];
int temp = N / m;
for (int i = 0; i < m; i++)
{
mydt[i].A = A, mydt[i].B = revB, mydt[i].C = C, mydt[i].block = temp;
mydt[i].begin = i * temp, mydt[i].end = i * temp + temp;
if (i == m - 1)//最后一个线程
{
mydt[i].end = N * N;
}
hThread[i] = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)ThreadProc, &mydt[i], 0, NULL);//创建线程
}
WaitForMultipleObjects(m, hThread, TRUE, INFINITE);
end = clock();
cout << m << "个线程时运行时间为" << (double)(end - start) / CLOCKS_PER_SEC << "s" << endl;
}
}