From self-attention to flash-attention:数学原理与 CUDA 实现优化

self-attention 是 transformer 编码器和解码器中共有的计算环节,在整个 transformer 网络中所耗费的算力占主导。所以缩短 self-attention 正向和反向的计算时间,就可以加速 transformer 的训练和推理过程。

1,self attention 的数学提炼

两个矩阵乘法,中间加入一个按列方向的 softmax:

input 矩阵:\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}

output 矩阵:\mathbf{O} \in \mathbb{R}^{N \times d}

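\textbf{self-attention algorithm:}

        step1:        \mathbf{S} = \mathbf{Q}\mathbf{K}^{T} \in \mathbb{R}^{N \times N}

        step2:        \mathbf{P} = \mathrm{softmax}_{\mathrm{column}}(\mathbf{S}) \in \mathbb{R}^{N \times N}

        step3:        \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{N \times d}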

2,cpu 实现self attention

这里的数据类型使用了 float,实际网络中一般采用 fp16,数学过程是相同的;

cpu_self_attention.cpp

#include <stdio.h>
#include <string.h>

#include "cpu_gemm.h"
#include "utils.h"
#include "soft_max.h"
//all matrices are row major.

// Reference (unfused) self-attention on the CPU. All matrices are row-major.
//   Q, K, V : N x d  input,   leading dimensions ldq / ldk / ldv
//   S, P    : N x N  caller-provided scratch, leading dimensions lds / ldp
//   O       : N x d  output,  leading dimension ldo
// The intermediate S and P matrices are printed for inspection.
// NOTE(review): soft_max_column normalizes each *column* of S, whereas standard
// attention (and flash_attention_cpu below, which uses rowmax/rowsum) normalizes
// each *row*; the two results are therefore not directly comparable — confirm intent.
void cpu_self_attention(float* Q, int ldq,
						float* K, int ldk,
						float* V, int ldv,
						float* S, int lds,
						float* P, int ldp,
						float* O, int ldo,
						int N, int d)
{
	gemm_nt(Q, ldq, K, ldk, S, lds, N, N, d);// S = Q*K^t     (NxN) = (Nxd) * (dxN)
	printf("\nS =\n");	print_matrix(S, N, N, lds);
	soft_max_column(P, ldp, S, lds, N, N);// P(NxN) = softmax(S(NxN))
	printf("\nP =\n");	print_matrix(P, N, N, ldp);// dump P itself (previously printed S again)
	gemm_nn(P, ldp, V, ldv, O, ldo, N, d, N);// O = P*V     (Nxd) = (NxN) * (Nxd)
}

cpu_gemm.cpp

#include "cpu_gemm.h"

// Naive GEMM, C = A * B, every operand row-major:
//   A is M x K (leading dim lda), B is K x N (ldb), C is M x N (ldc).
// Each C element is accumulated in float over k = 0..K-1 and then stored.
void gemm_nn(float *A, int lda,		//A(M x K) rowMj
	     	 float *B, int ldb,		//B(K x N) rowMj
	     	 float *C, int ldc,		//C(M x N) rowMj
	      	 int M,
			 int N,
			 int K)
{
	for (int row = 0; row < M; ++row) {
		const float* a_row = A + row*lda;
		float* c_row = C + row*ldc;

		for (int col = 0; col < N; ++col) {
			float acc = 0.0f;
			for (int t = 0; t < K; ++t)
				acc += a_row[t] * B[t*ldb + col];
			c_row[col] = acc;
		}
	}
}

// Naive GEMM with B transposed, C = A * B^t, every operand row-major:
//   A is M x K (lda), B is N x K (ldb), C is M x N (ldc).
// C(i,j) is the float dot product of row i of A with row j of B.
void gemm_nt(float *A, int lda,		//A(M x K) rowMj
	     	 float *B, int ldb,		//B(N x K) rowMj
	     	 float *C, int ldc,		//C(M x N) rowMj
	      	 int M,
			 int N,
			 int K)
{
	for (int row = 0; row < M; ++row) {
		const float* a_row = A + row*lda;

		for (int col = 0; col < N; ++col) {
			const float* b_row = B + col*ldb;	// row `col` of B == column `col` of B^t
			float acc = 0.0f;
			for (int t = 0; t < K; ++t)
				acc += a_row[t] * b_row[t];
			C[row*ldc + col] = acc;
		}
	}
}

cpu_softmax_column.cpp

这里使用的是未做数值稳定优化的方式(没有先减去每列的最大值),直接按照原始公式计算;当 S 中元素较大时 exp 可能溢出:

#include "soft_max.h"
// Column-wise softmax of an M x N row-major matrix, computed straight from the formula
// (no max subtraction, so large entries of S can overflow exp()):
//   P(i,j) = exp(S(i,j)) / sum_r exp(S(r,j)),   r = 0, 1, ..., M-1
// lds / ldp are the leading dimensions of S / P.
// P may alias S with the same leading dimension: each element is read before it is written.
void soft_max_column(float *P, int ldp, float* S, int lds, int M, int N)
{
    for(int j=0; j<N; j++){
        float sigma = 0.0f;

        // pass 1: P(:,j) = exp(S(:,j)) while accumulating the column sum
        for(int i=0; i<M; i++){
            const float e = exp(S[i*lds + j]);
            P[i*ldp + j] = e;
            sigma += e;
        }

        // pass 2: normalize the column (the old code divided the raw S, not exp(S))
        for(int i=0; i<M; i++){
            P[i*ldp + j] /= sigma;
        }
    }
}

3, gpu 实现 self attention 正向

cuda 实现上述过程:

gpu_self_attention.cu

gpu_gemm.cu

gpu_softmax_column.cu

4,为什么不需要gpu 实现self attention 反向

融合上述过程

5,CPU 版本的 flash attention,用于对算法做验证


#include <stdio.h>
#include <string.h>
#include <math.h>
#include <iostream>
#include <limits>

#include "cpu_self_attention.h"
#include "utils.h"
#include "cpu_gemm.h"

#define N	32//2048
#define d	16//512
#define M 	(4*4*4*4)//(128*1024)//128K*4 Bytes,   4 = sizeof(float)

// Fill m[0..len) with negative infinity of type T (the initial running row-max
// of the online softmax). The sentinel value is also echoed to stdout.
template<typename T>
void set_vector_inf_neg(T* m, int len)
{
	// Build the sentinel in T itself rather than in float, so the template is type-consistent.
	const T inf_neg = -std::numeric_limits<T>::infinity();
	std::cout<<"inf_neg = "<< inf_neg<<std::endl;

	for(int idx=0; idx<len; idx++)
		m[idx] = inf_neg;
}

// Copy row-block `idx` of A into B, where a block is `row` consecutive rows:
//   B(i, j) = A(idx*row + i, j),   0 <= i < row, 0 <= j < col   (col is usually d)
// Both matrices are row-major; lda / ldb are their leading dimensions.
void load_matrix_block(float* A, int lda, int idx, float* B, int ldb, int row, int col)
{
	const float* src = A + idx*row*lda;	// first row of block idx
	for(int i=0; i<row; i++){//row == Bc, Br
		for(int j=0; j<col; j++){//col == d
			// advance by i*lda per row; the old index copied the block's first row every time
			B[i*ldb + j] = src[i*lda + j];
		}
	}
}

// Copy len_cp entries of segment `idx` of AY (segments are len_a long) into BY:
//   BY[k] = AY[idx*len_a + k],   0 <= k < len_cp
void load_vector_sgmnt(float* AY, int len_a, int idx, float* BY, int len_cp)
{
	const float* seg = AY + idx*len_a;
	for (int k = 0; k < len_cp; ++k)
		BY[k] = seg[k];
}

// Row-wise maximum of the MM x NN row-major matrix S (leading dim lds):
//   mij[r] = max_c S(r, c); an empty row (NN == 0) yields -infinity.
void rowmax(int MM, int NN, float *S, int lds, float* mij)
{
	for (int r = 0; r < MM; ++r) {
		const float* row = S + r*lds;
		float best = -std::numeric_limits<float>::infinity();

		for (int c = 0; c < NN; ++c) {
			if (best < row[c])
				best = row[c];
		}
		mij[r] = best;
	}
}

// Shifted exponent of an MM x NN row-major matrix:
//   P(r, c) = exp(S(r, c) - mij[r])
// mij holds one shift per row (the row maximum in flash attention); lds / ldp are leading dims.
void exp_matrix_sub_rowmax(int MM, int NN, float* P, int ldp, float* S, int lds, float* mij)
{
	for (int r = 0; r < MM; ++r) {
		const float* s_row = S + r*lds;
		float* p_row = P + r*ldp;

		for (int c = 0; c < NN; ++c)
			p_row[c] = exp(s_row[c] - mij[r]);
	}
}

// Row-wise sum of the MM x NN row-major matrix P (leading dim ldp):
//   lij[r] = sum_c P(r, c), accumulated left to right in float.
void rowsum(int MM, int NN, float* P, int ldp, float* lij)
{
	for (int r = 0; r < MM; ++r) {
		const float* p_row = P + r*ldp;
		float total = 0.0f;

		for (int c = 0; c < NN; ++c)
			total += p_row[c];
		lij[r] = total;
	}
}

// Element-wise maximum of two length-MM vectors: mi_new[k] = max(m_i[k], m_ij[k]).
// When neither operand compares greater (ties), m_ij[k] is taken.
void vector_max(int MM, float* m_i, float* m_ij, float* mi_new)
{
	for (int k = 0; k < MM; ++k) {
		const float a = m_i[k];
		const float b = m_ij[k];

		if (b < a)
			mi_new[k] = a;
		else
			mi_new[k] = b;
	}
}

void flash_attention_cpu(float* Q, int ldq, float* K, int ldk, float* V, int ldv)
{
	//step 01
	constexpr int Br = M/(4*d);
	constexpr int Bc = (M/(4*d))<d? (M/(4*d)): d;
	constexpr int Tr = N/Br;
	constexpr int Tc = N/Bc;

	std::cout<< "Br ="<<Br<<" Bc = "<<Bc<<" Tr = "<<Tr<<" Tc = "<<Tc<<std::endl;
	//step 02
	float* O = nullptr;// O(N x d)
	int ldo = d;
	float* l = nullptr;
	float* m = nullptr;

	O = (float*)malloc(N*ldo*sizeof(float));// O(N x d)
	l = (float*)malloc(N*sizeof(float));
	m = (float*)malloc(N*sizeof(float));
	memset(O, 0x00, N*ldo*sizeof(float));// O(N x d)
	memset(l, 0x00, N*sizeof(float));
	set_vector_inf_neg<float>(m, N);
	//step 03
	// Q => Q_1, Q_2, ..., Q_Tr; Q_i(Br x d)
	// K => K_1, K_2, ..., K_Tc; K_j(Bc x d)
	// V => V_1, V_2, ..., V_Tc; V_j(Bc x d)

	//step 04
	// O => O_1, O_2, ..., O_Tr; O_i(Br x d)
	// l => l_1, l_2, ..., l_Tr; l_i(Br x 1)
	// m => m_1, m_2, ..., m_Tr; m_i(Br x 1)
	//step 05
	float* K_j = nullptr;
	float* V_j = nullptr;
	int ldkj  = d;
	int ldvj  = d;

	K_j = (float*)malloc(Bc*d*sizeof(float));
	V_j = (float*)malloc(Bc*d*sizeof(float));
	/ to step 08
	float* Q_i = nullptr;// Q_i(Br x d)
	float* O_i = nullptr;// O_i(Br x d)
	int ldqi = d;
	int ldoi = d;
	float* l_i = nullptr;// l_i(Br x 1)
	float* m_i = nullptr;// m_i(Br x 1)

	Q_i = (float*)malloc(Br*d*sizeof(float));
	O_i = (float*)malloc(Br*d*sizeof(float));
	l_i = (float*)malloc(Br*sizeof(float));
	m_i = (float*)malloc(Br*sizeof(float));
	/ to step 09
	float* S_ij = nullptr;// Sij(Br x Bc)
	int ldsij = Bc;
	S_ij = (float*)malloc(Br*ldsij*sizeof(float));
	/ to step 10.1
	float* m_ij = nullptr;
	m_ij = (float*)malloc(Br*sizeof(float));

	/ to step 10.2
	float* P_ij = nullptr;
	int ldpij = Bc;
	P_ij = (float*)malloc(Br*ldpij*sizeof(float));

	/ to step 10.3
	float* l_ij = nullptr;
	l_ij = (float*)malloc(Br*sizeof(float));

	/ to step 11.1
	float* mi_new = nullptr;
	mi_new = (float*)malloc(Br*sizeof(float));

	/
	for(int j=0; j<Tc; j++){
		//step 06 load Kj, Vj
		load_matrix_block(K, ldk, j, K_j, ldkj, Bc, d);// Kj(Bc x d)  column all are d;
		load_matrix_block(V, ldv, j, V_j, ldvj, Bc, d);
		//step 07 for
		for(int i=0; i<Tr; i++){
			//step 08 load Q_i, O_i, l_i, m_i from HBM to SRAM
			load_matrix_block(Q, ldq, i, Q_i, ldqi, Br, d);//Q_i(Br x d)
			load_matrix_block(O, ldo, i, O_i, ldoi, Br, d);// O_i(Br x d)
			load_vector_sgmnt(l, Br, i, l_i, Br);			//l_i(Br x 1)
			load_vector_sgmnt(m, Br, i, m_i, Br);			// m_i(Br x 1)
			//step 09 S_ij = Qi*(K^t)j ; S_ij(Br x Bc)
			gemm_nt(Q_i,  ldqi,
					K_j,  ldkj,
					S_ij, ldsij,
					Br,
					Bc,
					d);
			//step 10
				//step 10.1 m~ij = rowmax(S_ij)			(Br x 1)
			rowmax(Br, Bc, S_ij, ldsij, m_ij);
				//step 10.2 P~ij = exp(S_ij - m~ij)     (Br x Bc)
				// P_ij = S_ij - m_ij
			exp_matrix_sub_rowmax(Br, Bc, P_ij, ldpij, S_ij, ldsij, m_ij);
				//step 10.3 l~ij = rowsum(P~ij)         (Br x 1)
			rowsum(Br, Bc, P_ij, ldpij, l_ij);
			//step 11
				//step 11.1 mi_new = max(mi, mij);		(Br x 1)
			vector_max(Br, m_i, m_ij, mi_new);
				//step 11.2 li_new = exp(mi - mi_new)*li + exp(mij - mi_new)*lij	(Br x 1)



		}
	}



}

// Driver: builds N x d inputs Q, K, V, prints them, and runs the CPU flash-attention check.
int main()
{
	// Problem size taken from the N / d macros (the trailing comments show a larger setup).
	int NN = N;//1024;
	int dd = d;//64;

	// Densely packed row-major inputs: leading dimension == number of columns.
	int ldq = dd;//512
	int ldk = dd;//512
	int ldv = dd;//512

	float *Q_h = (float*)malloc(NN*ldq*sizeof(float));//(Nxd)
	float *K_h = (float*)malloc(NN*ldk*sizeof(float));//(Nxd)
	float *V_h = (float*)malloc(NN*ldv*sizeof(float));//(Nxd)

	// NOTE(review): the last init_matrix argument looks like a seed — TODO confirm in utils.h.
	init_matrix(Q_h, NN, dd, ldq, 2025);
	printf("\nQ_h =\n");
	print_matrix(Q_h, NN, dd, ldq);

	init_matrix(K_h, NN, dd, ldk, 2027);
	printf("\nK_h =\n");
	print_matrix(K_h, NN, dd, ldk);

	init_matrix(V_h, NN, dd, ldv, 2026);
	printf("\nV_h =\n");
	print_matrix(V_h, NN, dd, ldv);

	//cpu_self_attention(Q_h, ldq, K_h, ldk, V_h, ldv, NN, dd);
	flash_attention_cpu(Q_h, ldq, K_h, ldk, V_h, ldv);

	free(Q_h);
	free(K_h);
	free(V_h);

	return 0;
}

存稿防丢,未完待续。。。

6, gpu 实现 flash attention 前向

融合算子

数学原理

cuda 实现

挖坑,未完待续 。。。

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值