Self attention is a computation shared by the transformer encoder and decoder, and it accounts for the dominant share of compute in the overall transformer network. Cutting the forward and backward computation time of self attention therefore speeds up both transformer training and inference.
1. The math of self attention
Two matrix multiplications with a column-wise softmax in between (the usual 1/sqrt(d) scaling factor is omitted in this code; it does not change the structure of the computation).
Input matrices: Q(N x d), K(N x d), V(N x d)
Output matrix: O(N x d)
step1: S = Q * K^T, S(N x N)
step2: P = softmax(S), applied column-wise, P(N x N)
step3: O = P * V, O(N x d)
2. CPU implementation of self attention
The data type used here is float; real networks typically use fp16, but the math is identical.
cpu_self_attention.cpp
#include <stdio.h>
#include <string.h>
#include "cpu_gemm.h"
#include "utils.h"
#include "soft_max.h"
//all matrices are row major.
void cpu_self_attention(float* Q, int ldq,
                        float* K, int ldk,
                        float* V, int ldv,
                        float* S, int lds,
                        float* P, int ldp,
                        float* O, int ldo,
                        int N, int d)
{
    gemm_nt(Q, ldq, K, ldk, S, lds, N, N, d);   // S = Q*K^T (NxN) = (Nxd) * (dxN)
    printf("\nS =\n"); print_matrix(S, N, N, lds);
    soft_max_column(P, ldp, S, lds, N, N);      // P(NxN) = softmax(S(NxN)), column-wise
    printf("\nP =\n"); print_matrix(P, N, N, ldp);
    gemm_nn(P, ldp, V, ldv, O, ldo, N, d, N);   // O = P*V (Nxd) = (NxN) * (Nxd)
}
cpu_gemm.cpp
#include "cpu_gemm.h"
void gemm_nn(float *A, int lda,  // A(M x K) row-major
             float *B, int ldb,  // B(K x N) row-major
             float *C, int ldc,  // C(M x N) row-major
             int M,
             int N,
             int K)
{
    for(int i=0; i<M; i++)
    {
        for(int j=0; j<N; j++)
        {
            float sigma = 0.0f;
            for(int k=0; k<K; k++)
            {
                sigma += A[i*lda + k] * B[k*ldb + j];
            }
            C[i*ldc + j] = sigma;
        }
    }
}
void gemm_nt(float *A, int lda,  // A(M x K) row-major
             float *B, int ldb,  // B(N x K) row-major, accessed as B^T
             float *C, int ldc,  // C(M x N) row-major
             int M,
             int N,
             int K)
{
    for(int i=0; i<M; i++)
    {
        for(int j=0; j<N; j++)
        {
            float sigma = 0.0f;
            for(int k=0; k<K; k++)
            {
                sigma += A[i*lda + k] * B[k + j*ldb];   // B(j,k), i.e. B^T(k,j)
            }
            C[i*ldc + j] = sigma;
        }
    }
}
cpu_softmax_column.cpp
The numerically unoptimized form is used here, computing directly from the original formula:
#include "soft_max.h"
void soft_max_column(float *P, int ldp, float* S, int lds, int M, int N)//P = softmax(S) P(i,j) = exp(S(i,j))/sigma(exp(S(r,j))); r=0,1,..,n-1 ;
{
for(int j=0; j<N; j++){
float sigma = 0.0f;
for(int i=0; i<M; i++){
sigma += exp(S[i*lds + j])
}
for(int i=0; i<M; i++){
P[i*ldp + j] = S[i*lds + j]/sigma;
}
}
}
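As an aside, the numerically stable variant subtracts the column max before exponentiating, which is exactly the trick flash attention applies per block in section 5. A minimal sketch (the function name soft_max_column_stable is mine, not part of the original code):
#include <math.h>
// Numerically stable column softmax: subtract the per-column max so exp() never overflows.
void soft_max_column_stable(float *P, int ldp, float* S, int lds, int M, int N)
{
    for(int j=0; j<N; j++){
        float mj = S[j];   // column max, start with row 0
        for(int i=1; i<M; i++){
            if(S[i*lds + j] > mj) mj = S[i*lds + j];
        }
        float sigma = 0.0f;
        for(int i=0; i<M; i++){
            sigma += exp(S[i*lds + j] - mj);   // shifted exponentials
        }
        for(int i=0; i<M; i++){
            P[i*ldp + j] = exp(S[i*lds + j] - mj)/sigma;
        }
    }
}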
3. GPU implementation of the self attention forward pass
Implementing the above process in CUDA, split across the files below (a rough sketch of what the kernels could look like follows the list):
gpu_self_attention.cu
gpu_gemm.cu
gpu_softmax_column.cu
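These files are still empty in this draft. Below is a minimal, unoptimized sketch of the two nontrivial kernels: one thread per output element, no tiling or shared memory. Kernel names and launch configuration are my assumptions, not final code:
// gpu_gemm.cu -- naive sketch: one thread per C(i,j); C = A * B^T, all row-major
__global__ void gemm_nt_kernel(const float* A, int lda,
                               const float* B, int ldb,
                               float* C, int ldc,
                               int M, int N, int K)
{
    int i = blockIdx.y*blockDim.y + threadIdx.y;
    int j = blockIdx.x*blockDim.x + threadIdx.x;
    if(i < M && j < N){
        float sigma = 0.0f;
        for(int k=0; k<K; k++){
            sigma += A[i*lda + k] * B[j*ldb + k];   // B^T access
        }
        C[i*ldc + j] = sigma;
    }
}

// gpu_softmax_column.cu -- naive sketch: one thread per column of S
__global__ void soft_max_column_kernel(float* P, int ldp,
                                       const float* S, int lds,
                                       int M, int N)
{
    int j = blockIdx.x*blockDim.x + threadIdx.x;
    if(j < N){
        float sigma = 0.0f;
        for(int i=0; i<M; i++){
            sigma += expf(S[i*lds + j]);
        }
        for(int i=0; i<M; i++){
            P[i*ldp + j] = expf(S[i*lds + j])/sigma;
        }
    }
}
gpu_self_attention.cu would then launch the kernels in sequence on device buffers, e.g. gemm_nt_kernel<<<dim3((N+15)/16,(N+15)/16), dim3(16,16)>>>(...), followed by soft_max_column_kernel and a gemm_nn kernel for O = P*V.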
4. Why a separate GPU implementation of the self attention backward pass is not needed
Because the steps above get fused: flash attention recomputes S and P block by block in its backward pass instead of storing the N x N matrices and reading them back from memory, so no standalone backward kernel for the three-step pipeline above is required. The identities involved are sketched below.
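For reference, these are the standard attention backward identities, using the notation of section 1 and assuming the row-wise softmax formulation that flash attention uses (this sketch is mine); every term can be rebuilt per block from Q, K, V and the saved softmax statistics:
dV = P^T * dO                          (N x N)*(N x d) = (N x d)
dP = dO * V^T                          (N x d)*(d x N) = (N x N)
dS = P ∘ (dP - rowsum(dP ∘ P))         softmax backward, ∘ = elementwise product
dQ = dS * K                            (N x N)*(N x d) = (N x d)
dK = dS^T * Q                          (N x N)*(N x d) = (N x d)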
5. CPU version of flash attention, to validate the algorithm
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <iostream>
#include <limits>
#include "cpu_self_attention.h"
#include "utils.h"
#include "cpu_gemm.h"
#define N 32        // 2048 in a real model
#define d 16        // 512 in a real model
#define M (4*4*4*4) // SRAM budget in floats; (128*1024) in a real setting: 128K * 4 bytes, 4 = sizeof(float)
template<typename T>
void set_vector_inf_neg(T* m, int len)
{
    T inf_neg = -std::numeric_limits<T>::infinity();
    std::cout<<"inf_neg = "<<inf_neg<<std::endl;
    for(int idx=0; idx<len; idx++)
        m[idx] = inf_neg;
}
// copy the idx-th row block A(idx*row, 0) into B(0, 0); B(row x col), default col = d
void load_matrix_block(float* A, int lda, int idx, float* B, int ldb, int row, int col)
{
    for(int i=0; i<row; i++){   // row == Bc or Br
        for(int j=0; j<col; j++){   // col == d
            B[i*ldb + j] = A[(idx*row + i)*lda + j];
        }
    }
}
// copy the idx-th length-len_cp segment of AY into BY; segments are len_a apart
void load_vector_sgmnt(float* AY, int len_a, int idx, float* BY, int len_cp)
{
    for(int i=0; i<len_cp; i++){
        BY[i] = AY[idx*len_a + i];
    }
}
void rowmax(int MM, int NN, float *S, int lds, float* mij)   // mij[i] = max over row i of S
{
    float inf_neg = -std::numeric_limits<float>::infinity();
    for(int i=0; i<MM; i++){
        float row_max = inf_neg;
        for(int j=0; j<NN; j++){
            if(row_max < S[i*lds + j])
                row_max = S[i*lds + j];
        }
        mij[i] = row_max;
    }
}
void exp_matrix_sub_rowmax(int MM, int NN, float* P, int ldp, float* S, int lds, float* mij)   // P(i,j) = exp(S(i,j) - mij[i])
{
    for(int i=0; i<MM; i++){
        for(int j=0; j<NN; j++){
            P[i*ldp + j] = exp(S[i*lds + j] - mij[i]);
        }
    }
}
void rowsum(int MM, int NN, float* P, int ldp, float* lij)   // lij[i] = sum over row i of P
{
    for(int i=0; i<MM; i++){
        float sigma = 0.0f;
        for(int j=0; j<NN; j++){
            sigma += P[i*ldp + j];
        }
        lij[i] = sigma;
    }
}
void vector_max(int MM, float* m_i, float* m_ij, float* mi_new)   // elementwise max of two vectors
{
    for(int i=0; i<MM; i++)
        mi_new[i] = (m_i[i] > m_ij[i]) ? m_i[i] : m_ij[i];
}
void flash_attention_cpu(float* Q, int ldq, float* K, int ldk, float* V, int ldv)
{
    //step 01: block sizes (note the FlashAttention paper sets Bc = M/(4d) and Br = min(M/(4d), d);
    //the roles are swapped here, which makes no difference for these toy sizes)
    constexpr int Br = M/(4*d);
    constexpr int Bc = (M/(4*d))<d ? (M/(4*d)) : d;
    constexpr int Tr = N/Br;
    constexpr int Tc = N/Bc;
    std::cout<<"Br = "<<Br<<" Bc = "<<Bc<<" Tr = "<<Tr<<" Tc = "<<Tc<<std::endl;
    //step 02: allocate O(N x d), l(N), m(N) and initialize
    float* O = nullptr;   // O(N x d)
    int ldo = d;
    float* l = nullptr;
    float* m = nullptr;
    O = (float*)malloc(N*ldo*sizeof(float));
    l = (float*)malloc(N*sizeof(float));
    m = (float*)malloc(N*sizeof(float));
    memset(O, 0x00, N*ldo*sizeof(float));   // O = 0
    memset(l, 0x00, N*sizeof(float));       // l = 0
    set_vector_inf_neg<float>(m, N);        // m = -inf
    //step 03: partition the inputs
    // Q => Q_1, Q_2, ..., Q_Tr;  Q_i(Br x d)
    // K => K_1, K_2, ..., K_Tc;  K_j(Bc x d)
    // V => V_1, V_2, ..., V_Tc;  V_j(Bc x d)
    //step 04: partition the outputs and statistics
    // O => O_1, O_2, ..., O_Tr;  O_i(Br x d)
    // l => l_1, l_2, ..., l_Tr;  l_i(Br x 1)
    // m => m_1, m_2, ..., m_Tr;  m_i(Br x 1)
    //step 05: buffers for one K/V block
    float* K_j = nullptr;
    float* V_j = nullptr;
    int ldkj = d;
    int ldvj = d;
    K_j = (float*)malloc(Bc*d*sizeof(float));
    V_j = (float*)malloc(Bc*d*sizeof(float));
    // to step 08: buffers for one Q/O block and its l, m segments
    float* Q_i = nullptr;   // Q_i(Br x d)
    float* O_i = nullptr;   // O_i(Br x d)
    int ldqi = d;
    int ldoi = d;
    float* l_i = nullptr;   // l_i(Br x 1)
    float* m_i = nullptr;   // m_i(Br x 1)
    Q_i = (float*)malloc(Br*d*sizeof(float));
    O_i = (float*)malloc(Br*d*sizeof(float));
    l_i = (float*)malloc(Br*sizeof(float));
    m_i = (float*)malloc(Br*sizeof(float));
    // to step 09
    float* S_ij = nullptr;  // S_ij(Br x Bc)
    int ldsij = Bc;
    S_ij = (float*)malloc(Br*ldsij*sizeof(float));
    // to step 10.1
    float* m_ij = nullptr;
    m_ij = (float*)malloc(Br*sizeof(float));
    // to step 10.2
    float* P_ij = nullptr;
    int ldpij = Bc;
    P_ij = (float*)malloc(Br*ldpij*sizeof(float));
    // to step 10.3
    float* l_ij = nullptr;
    l_ij = (float*)malloc(Br*sizeof(float));
    // to step 11.1
    float* mi_new = nullptr;
    mi_new = (float*)malloc(Br*sizeof(float));
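    // to step 11.2 (sketch: the li_new buffer is my addition, to carry the step 11.2 result below)
    float* li_new = nullptr;
    li_new = (float*)malloc(Br*sizeof(float));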
    //
    for(int j=0; j<Tc; j++){
        //step 06: load K_j, V_j (in the GPU version: from HBM to SRAM)
        load_matrix_block(K, ldk, j, K_j, ldkj, Bc, d);   // K_j(Bc x d)
        load_matrix_block(V, ldv, j, V_j, ldvj, Bc, d);   // V_j(Bc x d)
        //step 07: inner loop over the Q blocks
        for(int i=0; i<Tr; i++){
            //step 08: load Q_i, O_i, l_i, m_i (from HBM to SRAM)
            load_matrix_block(Q, ldq, i, Q_i, ldqi, Br, d);   // Q_i(Br x d)
            load_matrix_block(O, ldo, i, O_i, ldoi, Br, d);   // O_i(Br x d)
            load_vector_sgmnt(l, Br, i, l_i, Br);   // l_i(Br x 1)
            load_vector_sgmnt(m, Br, i, m_i, Br);   // m_i(Br x 1)
            //step 09: S_ij = Q_i * K_j^T; S_ij(Br x Bc)
            gemm_nt(Q_i, ldqi,
                    K_j, ldkj,
                    S_ij, ldsij,
                    Br,
                    Bc,
                    d);
            //step 10
            //step 10.1: m~ij = rowmax(S_ij)   (Br x 1)
            rowmax(Br, Bc, S_ij, ldsij, m_ij);
            //step 10.2: P~ij = exp(S_ij - m~ij)   (Br x Bc)
            exp_matrix_sub_rowmax(Br, Bc, P_ij, ldpij, S_ij, ldsij, m_ij);
            //step 10.3: l~ij = rowsum(P~ij)   (Br x 1)
            rowsum(Br, Bc, P_ij, ldpij, l_ij);
            //step 11
            //step 11.1: mi_new = max(m_i, m~ij)   (Br x 1)
            vector_max(Br, m_i, m_ij, mi_new);
            //step 11.2: li_new = exp(m_i - mi_new)*l_i + exp(m~ij - mi_new)*l~ij   (Br x 1)
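            // (sketch: implements the formula in the comment above)
            for(int r=0; r<Br; r++){
                li_new[r] = exp(m_i[r] - mi_new[r])*l_i[r] + exp(m_ij[r] - mi_new[r])*l_ij[r];
            }
            //step 12: still to do in this draft; per the FlashAttention paper, Algorithm 1:
            //  O_i <- diag(li_new)^-1 * ( diag(l_i)*exp(m_i - mi_new)*O_i + exp(m~ij - mi_new)*P~ij*V_j )
            //  then write O_i back to O, and li_new, mi_new back to l, m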
        }
    }
}
int main()
{
    int NN, dd;
    NN = N;   // e.g. 1024 in a real setting
    dd = d;   // e.g. 64 in a real setting
    int ldq, ldk, ldv;
    float *Q_h = nullptr;   // (N x d)
    float *K_h = nullptr;   // (N x d)
    float *V_h = nullptr;   // (N x d)
    ldq = dd;
    ldk = dd;
    ldv = dd;
    Q_h = (float*)malloc(NN*ldq*sizeof(float));
    K_h = (float*)malloc(NN*ldk*sizeof(float));
    V_h = (float*)malloc(NN*ldv*sizeof(float));
    init_matrix(Q_h, NN, dd, ldq, 2025); printf("\nQ_h =\n"); print_matrix(Q_h, NN, dd, ldq);
    init_matrix(K_h, NN, dd, ldk, 2027); printf("\nK_h =\n"); print_matrix(K_h, NN, dd, ldk);
    init_matrix(V_h, NN, dd, ldv, 2026); printf("\nV_h =\n"); print_matrix(V_h, NN, dd, ldv);
    //cpu_self_attention(Q_h, ldq, K_h, ldk, V_h, ldv, S, lds, P, ldp, O, ldo, NN, dd);   // reference path; needs S, P, O buffers
    flash_attention_cpu(Q_h, ldq, K_h, ldk, V_h, ldv);
    free(Q_h);
    free(K_h);
    free(V_h);
    return 0;
}
Saving this draft so it doesn't get lost; to be continued...
6. GPU implementation of the flash attention forward pass
Operator fusion
Mathematical principles
CUDA implementation (a rough kernel skeleton is sketched below)
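As a placeholder until the full writeup, here is a minimal sketch of how the fused forward kernel could be organized: one thread block per Q block, one thread per row of Q, K/V tiles staged through shared memory, and the online-softmax running state (m, l) kept in registers. Kernel name, block sizes, and launch configuration are my assumptions, and N is assumed to be a multiple of Br and Bc:
#include <cuda_runtime.h>
#include <math.h>

#define Br 16   // rows of Q per thread block (assumed)
#define Bc 16   // rows of K/V per tile (assumed)

// Fused forward: each thread owns one row of Q and streams over all K/V tiles,
// maintaining the running row max m, running exp-sum l, and output accumulator o.
// Row-major layout with leading dimension == D is assumed.
template<int D>   // head dimension, known at compile time
__global__ void flash_attention_fwd(const float* Q, const float* K, const float* V,
                                    float* O, int N)
{
    __shared__ float Ks[Bc][D];
    __shared__ float Vs[Bc][D];

    int row = blockIdx.x*Br + threadIdx.x;   // global Q row handled by this thread

    float q[D], o[D];
    for(int x=0; x<D; x++){ q[x] = Q[row*D + x]; o[x] = 0.0f; }
    float m = -INFINITY;   // running row max
    float l = 0.0f;        // running row sum of exponentials

    for(int j0=0; j0<N; j0+=Bc){
        // stage the K/V tile cooperatively (blockDim.x == Br threads)
        for(int r=threadIdx.x; r<Bc; r+=blockDim.x)
            for(int x=0; x<D; x++){
                Ks[r][x] = K[(j0+r)*D + x];
                Vs[r][x] = V[(j0+r)*D + x];
            }
        __syncthreads();

        for(int r=0; r<Bc; r++){
            float s = 0.0f;                    // s = q . k_r
            for(int x=0; x<D; x++) s += q[x]*Ks[r][x];
            float m_new = fmaxf(m, s);         // online softmax update
            float scale = expf(m - m_new);     // rescale the old state
            float p = expf(s - m_new);
            l = l*scale + p;
            for(int x=0; x<D; x++) o[x] = o[x]*scale + p*Vs[r][x];
            m = m_new;
        }
        __syncthreads();
    }
    for(int x=0; x<D; x++) O[row*D + x] = o[x]/l;   // final normalization
}
// launch sketch (Q_d, K_d, V_d, O_d are device buffers):
// flash_attention_fwd<16><<<N/Br, Br>>>(Q_d, K_d, V_d, O_d, N);
The real version additionally keeps per-row (m, l) in HBM for the backward pass and tiles over d when the head dimension no longer fits in registers; that is what the upcoming sections should cover.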
Placeholder; to be continued...