斯坦福开源新突破:百行代码助H100性能提升30%

提高 GPU 利用率,就是这么简单。

AI 的快速发展,伴随而来的是大计算量。这就自然而然的引出了一个问题:如何减少 AI 对计算的需求,并提高现有 AI 计算效率。

为了回答这一问题,来自斯坦福的研究者在博客《GPUs Go Brrr》中给出了答案。

图片

文章主要专注于两个问题:一是硬件真正需要什么?二是如何满足硬件需求?

文章用大量篇幅讨论了如何让 GPU 更快的运行,并发布了一个库 ThunderKittens,用户可以很容易地在 CUDA 上编写快速的深度学习内核。其具有以下特点:

  • 简单,ThunderKittens 写起来非常简单。
  • 可扩展性,如果用户需要 ThunderKittens 无法提供的功能,可以进行功能扩展。
  • 速度快。

图片

ThunderKittens 使得一些棘手的事情变得非常简单,从而在现代硬件上实现了非常高的利用率。项目中,作者用 ThunderKittens 编写了一个 RTX 4090 简单的 FlashAttention-2 内核,代码总共有 58 行代码(不包括空格),结果显示,ThunderKittens 在 RTX 4090 上实现了大约 122 TFLOP(理论最大值的 74%)。此外,内核程序只有 100 行的情况下,ThunderKittens 在 H100 上的性能比 FlashAttention-2 高出约 30%。

【一一AGI大模型学习 所有资源获取处一一】

①人工智能/大模型学习路线

②AI产品经理入门指南

③大模型方向必读书籍PDF版

④超详细海量大模型实战项目

⑤LLM大模型系统学习教程

⑥640套-AI大模型报告合集

⑦从0-1入门大模型教程视频

⑧AGI大模型技术公开课名额

英伟达 H100 有些小怪癖

该研究重点关注 NVIDIA H100,不过所介绍的内容也适用于其他 GPU。

图片

H100 SXM GPU 包含:

  • 80 GB HBM3,带宽为 3 TB/s(实际上带宽会少一些);
  • 50 MB 二级缓存,带宽 12 TB/s,在 GPU 上分成两个 25MB 的部分,通过 crossbar 连接;
  • 132 个流多处理器 (SM,streaming multiprocessors)。

除了上述这些,H100 SXM GPU 还有很多可关注的东西,例如内存控制器、指令缓存等。

研究者表示保持张量核心的运行流畅并不容易。他们发现了一些 AI 硬件上的怪癖,这些怪癖中的很多内容也适用于非 H100 GPU,但 H100 尤其棘手。(相比之下,RTX 4090 则非常容易使用),这些怪癖包括:

  • WGMMA 指令是必需的,但使用起来也非常令人恼火;
  • 共享内存实际上并没有那么快,并且需要非常小心;
  • 地址生成成本很高;
  • 占用率仍然有帮助,寄存器通常是关键资源。

图片

文章进一步描述了 GPU 这些怪癖的具体内容。

WGMMA 指令令人恼火

H100 有一组新指令,称为「warp group matrix multiply accumulate,WGMMA」(PTX 中的 wgmma.mma_async,或 SASS 中的 HGMMA/IGMMA/QGMMA/BGMMA)。以前的 GPU 上可用的张量核心指令是 wmma.mma.sync 和 mma.sync 。通过这些指令,SM 单个象限上的 32 个线程将同步地将其数据块馈送到张量核心并等待结果。

不同的是,wgmma.mma_async 指令并非如此,128 个连续线程(分布在 SM 的所有象限中)协作同步,并直接从共享内存(也可以选择寄存器)异步启动矩阵乘法。

在基准测试中,研究团队发现这些指令对于提取 H100 的完整计算是必要的。如果没有它们,GPU 的峰值利用率似乎只能达到峰值利用率的 63% 左右。

在这里插入图片描述

共享内存

共享内存的单次访问延迟约为 30 个周期,这听起来似乎不算多,但在这段时间内,SM 的张量核心几乎可以完成两个完整的 32x32 矩阵乘法运算。

共享内存处理起来有些棘手,因为它被存储(banked)在 32 个独立的内存存储中。如果不小心,这可能会导致所谓的 bank 冲突,即同一内存 bank 被要求同时提供多个不同的内存片段,导致请求被串行化,这可能会不成比例地减慢内核的速度 - 而 wgmma 和 mma 指令所需的寄存器布局会受到这些 bank 冲突的影响。解决方法是使用各种交错模式重新排列共享内存,以避免这些冲突。

地址生成

H100 其中一个特点是张量核心和内存都足够快,以至于仅仅生成用于获取数据的内存地址就占据了芯片资源的相当一部分。

NVIDIA 似乎已经意识到了这一点,因为他们赋予了 GPU 张量内存加速器(或称之为 TMA)。TMA 允许用户在全局和共享内存中指定多维张量布局,这节省了所有的地址生成成本,并且还使得构建 pipeline 更加容易。

研究团队还发现 TMA 和 wgmma.mma_async 一样,在实现 H100 的全部潜力方面是完全不可或缺的。

占用

在某些方面,与前几代硬件相比,H100 对占用率的依赖程度较低。NVIDIA 确实在设计 GPU 时考虑了占用率。虽然对于 H100 来说,占用率只能说有用,但作用不大。研究者发现在 A100 和 RTX 4090 上它变得越来越重要。

ThunderKittens

那么,如何才能更轻松地编写内核,同时仍兼具硬件的全部功能?

研究团队设计了一个嵌入 CUDA 中的 DSL,被命名为 ThunderKittens。

在这里插入图片描述

ThunderKittens 旨在尽可能简单,并包含四种模板类型:

  • 寄存器 tile—— 寄存器文件中的 2D 张量。
  • 寄存器向量 —— 寄存器文件中的 1D 张量。
  • 共享 tile—— 共享内存中的 2D 张量。
  • 共享向量 —— 共享内存中的 1D 张量。

tile 通过高度、宽度和布局进行参数化,寄存器向量由长度和布局参数化,共享向量仅由长度参数化。这样通常不会遭受 bank 冲突的困扰。

研究团队还提供了一些必要操作:

初始化,如将共享向量清零

  • 一元运算,如 exp
  • 二元运算,如 mul
  • 行 / 列操作,如 row_sum

该研究给出了一个用 ThunderKittens 编写的,用于 RTX 4090 的简单前向 flash attention 内核:

#define NUM_WORKERS 16 // This kernel uses 16 workers in parallel per block, to help issue instructions more quickly.
using namespace kittens; // this kernel only handles headdim=64 for simplicity. Also n should be a multiple of 256 here.
__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {

      auto warpid        = kittens::warpid();    
      auto block_start   = blockIdx.x*(n*64);    
      const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start;          
      bf16 *_o = __o__ + block_start;
      
      extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory
      shared_allocator al((int*)&__shm[0]);
      
      // K and V live in shared memory -- this is about all that will fit.
      st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
      st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
      
      // Initialize all of the register tiles.
      rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg need to be swapped into col_l
      rt_fl_1x1<> att_block;
      rt_bf_1x1<> att_block_mma;
      rt_fl_1x4<> o_reg;
      rt_fl_1x1<>::col_vec max_vec_last, max_vec; // these are column vectors for the attention block
      rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // these are column vectors for the attention block
      
      int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);
      
      for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {
      
      // each warp loads its own Q tile of 16x64, and then multiplies by 1/sqrt(d)
      load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
      mul(q_reg, q_reg, __float2bfloat16(0.125f)); // temperature adjustment
      
      // zero flash attention L, M, and O registers.
      neg_infty(max_vec); // zero registers for the Q chunk
      zero(norm_vec);
      zero(o_reg);
      
      // iterate over k, v for these q's that have been loaded
      for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {
      
          // each warp loads its own chunk of k, v into shared memory
          load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
          load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
          __syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase
          
          // now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg.
          for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {
          
          load(k_reg, k_smem[subtile]); // load k from shared into registers
          
          zero(att_block); // zero 16x16 attention tile
          mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T
          
          copy(norm_vec_last, norm_vec);
          copy(max_vec_last,  max_vec);
          
          row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec
          sub_row(att_block, att_block, max_vec); // subtract max from attention -- now all <=0
          exp(att_block, att_block); // exponentiate the block in-place.
          
          sub(max_vec_last, max_vec_last, max_vec); // subtract new max from old max to find the new normalization.
          exp(max_vec_last, max_vec_last); // exponentiate this vector -- this is what we need to normalize by.
          mul(norm_vec, norm_vec, max_vec_last); // and the norm vec is now normalized.
          
          mul(norm_vec, norm_vec, max_vec_last); // and the norm vec is now normalized.
          div_row(att_block, att_block, norm_vec); // now the attention block is correctly normalized
          
          mul(norm_vec_last, norm_vec_last, max_vec_last); // normalize the previous norm vec according to the new max
          div(norm_vec_last, norm_vec_last, norm_vec); // normalize the previous norm vec according to the new norm
          
          copy(att_block_mma, att_block); // convert to bf16 for mma_AB
          
          load(v_reg, v_smem[subtile]); // load v from shared into registers.
          rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg
          
          mul_row(o_reg, o_reg, norm_vec_last); // normalize o_reg in advance of mma_AB'ing onto it
          mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul.
      } 
      __syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk
   }
   store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/
}
}



      

总共大约有 60 行 CUDA 代码,硬件利用率为 75%,虽然非常密集,但大部分复杂性在于算法,而不是混合模式或寄存器布局。

TMA、WGMMA、swizzling 模式和描述符的复杂度又如何呢?如下是用 ThunderKittens 编写的, H100 的 FlashAttention-2 前向传递:

template<int D>
__global__  __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 2)
void fwd_attend_ker_dim(int N, const CUtensorMap* tma_q, const CUtensorMap* tma_k, const CUtensorMap* tma_v, CUtensorMap* tma_o) {    
extern __shared__ int __shm[]; // this is the CUDA shared memory    
tma_swizzle_allocator al((int*)&__shm[0]);

    constexpr int tile_width = fwd_attend_ker_tile_dims<D>::tile_width; // 
    constants    constexpr int qo_height  = fwd_attend_ker_tile_dims<D>::qo_height;    
    constexpr int kv_height  = fwd_attend_ker_tile_dims<D>::kv_height;
    
    st_bf<qo_height, tile_width, layout_q>          (&q_smem)   [NUM_WARPGROUPS] = al.allocate<st_bf<qo_height, tile_width, layout_q>,          NUM_WARPGROUPS>();    
    st_bf<kv_height, tile_width, layout_k>          (&k_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_k>, 2,       NUM_WORKERS_KV>();    
    st_bf<kv_height, tile_width, layout_v>          (&v_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_v>, 2,       NUM_WORKERS_KV>();
    
    int tic = 0, toc = 1;     
    rt_fl<1, kv_height> att_block;    
    rt_bf<1, kv_height> att_block_mma;    
    rt_fl<1, qo_height> o_prev;    
    col_vec<rt_fl<1, kv_height>> max_vec_last, max_vec;    
    col_vec<rt_fl<1, kv_height>> norm_vec_last, norm_vec;
    
    int warpid      = kittens::warpid();    
    int warpgroupid = warpid/kittens::WARPGROUP_WARPS;
    
    int kv_blocks = N / (NUM_WORKERS_KV*k_smem[0][0].rows);
    __shared__ uint64_t qsmem_barrier, kvsmem_barrier;//, vsmem_barrier;
    int q_phasebit = 0;    
    int kv_phasebit = 0;
    
    if (threadIdx.x == 0) {        
        tma::init_barrier<st_bf<qo_height, tile_width, layout_q>, NUM_WARPGROUPS>(qsmem_barrier, 1);             
        tma::init_barrier<st_bf<kv_height, tile_width, layout_k>, NUM_WORKERS_KV*2>(kvsmem_barrier, 1);     
    }
    if (warpid == 0) {        
        for (int wg = 0; wg < NUM_WORKERS/kittens::WARPGROUP_WARPS; wg++) { // load q            
            int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + wg;            
            tma::load_async((q_smem[wg]), tma_q, qsmem_barrier, tile_idx);         
        }        
        for (int w = 0; w < NUM_WORKERS_KV; w++) { // load k, v                  
            int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + (0 * NUM_WORKERS_KV) + w;             
            tma::load_async((k_smem[tic][w]), tma_k, kvsmem_barrier, tile_idx);             
            tma::load_async((v_smem[tic][w]), tma_v, kvsmem_barrier, tile_idx);         
        }    
    }
    neg_infty(max_vec); // zero registers for the Q chunk    
    zero(norm_vec);    
    zero(o_prev);    
    __syncthreads();
    
    tma::arrive_and_wait(qsmem_barrier, q_phasebit);    
    q_phasebit ^= 1;
    	
    if constexpr (D == 64) { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.125f)); }     
    else { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.08838834764f)); }
    
    for (auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++, tic ^= 1, toc ^= 1) {        
        tma::arrive_and_wait(kvsmem_barrier, kv_phasebit);        
        kv_phasebit ^= 1;
        
        __syncthreads();        
        if (warpid == 0) {            
           tma::set_bytes(kvsmem_barrier, 2 * NUM_WORKERS_KV * k_smem[0][0].num_elements * sizeof(bf16));
           
           if (kv_idx + 1 < kv_blocks) {                
               for (int w = 0; w < NUM_WORKERS_KV; w++) {                            
                   int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + ((kv_idx + 1) * NUM_WORKERS_KV) + w;                     
                   tma::load_async((k_smem[toc][w]), 
                   tma_k, kvsmem_barrier, tile_idx);                     tma::load_async((v_smem[toc][w]), tma_v, kvsmem_barrier, tile_idx);                
                   }            
            }        
      }
      warpgroup::mma_fence(att_block);        
      warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[tic][0]);        
      warpgroup::mma_commit_group();
      
      copy(norm_vec_last, norm_vec);       
      copy(max_vec_last,  max_vec);
      
      warpgroup::mma_async_wait();
      
      row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec        
      sub_row(att_block, att_block, max_vec);        
      exp(att_block, att_block);
      
      sub(max_vec_last, max_vec_last, max_vec);        
      exp(max_vec_last, max_vec_last);        
      mul(norm_vec, norm_vec, max_vec_last);
      
      row_sum(norm_vec, att_block, norm_vec); // accumulate onto the norm_vec        
      div_row(att_block, att_block, norm_vec);
      
      mul(norm_vec_last, norm_vec_last, max_vec_last);        
      div(norm_vec_last, norm_vec_last, norm_vec);
      
      copy(att_block_mma, att_block); // convert to bf16 for mma        
      mul_row(o_prev, o_prev, norm_vec_last); // normalize o_prev in advance of mma'ing onto it
      
      warpgroup::mma_fence(o_prev);        
      warpgroup::mma_AB(o_prev, att_block_mma, v_smem[tic][0]);        
      warpgroup::mma_commit_group();   
  }
      
  auto (*o_smem) = reinterpret_cast<st_bf<qo_height, tile_width, layout_o>(*)>(q_smem); // reuse q memory    
  warpgroup::store(o_smem[warpgroupid], o_prev);     
  __syncthreads();        
  
  if (warpid % 4 == 0) { // store o        
      int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + warpgroupid;        
      tma::store_async(tma_o, (o_smem[warpgroupid]), tile_idx);         
      tma::store_commit_group();     
  }
  	
    tma::store_async_wait();
}

这个内核只有 100 行代码,它在 H100 上的性能比 FlashAttention-2 高出约 30%。ThunderKittens 负责 wrap up 布局和指令,并提供一个可以在 GPU 上使用的 mini-pytorch。

在这里插入图片描述
H100 SXM 上各种配置的 FlashAttention-2(Pytorch)与 ThunderKittens 的比较。

此外,研究团队还发布了基于线性注意力的内核和其他架构。基于线性注意力内核的运行速度为 215 TFLOP(如果考虑算法中固有的重计算,则运行速度超过 300 TFLOP)。

虽然理论上线性注意力更高效,但从实践经验来看,线性注意力在硬件上的效率大大降低。因此,ThunderKittens 有望开辟广泛的高吞吐量应用。

在这里插入图片描述
)使用 ThunderKittens 可以非常快地实现线性注意力。*

tile 看起来是个好点子

在研究团队看来,ThunderKittens 之所以运行良好,是因为它不会试图做所有事情。CUDA 确实比 ThunderKittens 更有表现力,而 ThunderKittens 又小又简单。

不过,ThunderKittens 具有很好的抽象能力,它具有小的 tile,这与 AI 和硬件的发展相匹配。ThunderKittens 不支持任何少于 16 的维数。但在研究团队看来,这一点并不重要,尤其对于硬件而言。如果你的矩阵乘法小于 16x16,你确定自己做的还是 AI 吗?

从哲学的视角来看,研究团队认为框架迁移是合理的。「寄存器」当然不应该像旧 CPU 那样的 32 位。CUDA 使用的 1024 位宽向量寄存器无疑朝着正确方向迈出了一步。但对研究团队而言,「寄存器」是 16x16 的数据 tile。他们认为 AI 想要这样,它仍然只是矩阵乘法、规约和重塑。当然硬件也想要这样,小的矩阵乘法寻求硬件支持,而不仅仅是 systolic mma。

实际上,从更广泛的视角来看,研究团队认为应该围绕硬件的良好映射来重新调整 AI 思路。比如,循环状态应该有多大?SM 能够容纳多大尺寸?计算密度是多少?这些都不亚于硬件的要求。

研究团队表示,这项工作未来的一个重要方向是利用他们对硬件的了解来帮助设计与硬件相匹配的 AI。

最后,AMD 硬件上适配的 ThunderKittens 也将很快推出。

读者福利:如果大家对大模型感兴趣,这套大模型学习资料一定对你有用

对于0基础小白入门:

如果你是零基础小白,想快速入门大模型是可以考虑的。

一方面是学习时间相对较短,学习内容更全面更集中。
二方面是可以根据这些资料规划好学习计划和方向。

包括:大模型学习线路汇总、学习阶段,大模型实战案例,大模型学习视频,人工智能、机器学习、大模型书籍PDF。带你从零基础系统性的学好大模型!

😝有需要的小伙伴,可以保存图片到wx扫描二v码免费领取【保证100%免费】🆓
在这里插入图片描述

👉AI大模型学习路线汇总👈

大模型学习路线图,整体分为7个大的阶段:(全套教程文末领取哈)

第一阶段: 从大模型系统设计入手,讲解大模型的主要方法;

第二阶段: 在通过大模型提示词工程从Prompts角度入手更好发挥模型的作用;

第三阶段: 大模型平台应用开发借助阿里云PAI平台构建电商领域虚拟试衣系统;

第四阶段: 大模型知识库应用开发以LangChain框架为例,构建物流行业咨询智能问答系统;

第五阶段: 大模型微调开发借助以大健康、新零售、新媒体领域构建适合当前领域大模型;

第六阶段: 以SD多模态大模型为主,搭建了文生图小程序案例;

第七阶段: 以大模型平台应用与开发为主,通过星火大模型,文心大模型等成熟大模型构建大模型行业应用。

👉大模型实战案例👈

光学理论是没用的,要学会跟着一起做,要动手实操,才能将自己的所学运用到实际当中去,这时候可以搞点实战案例来学习。

在这里插入图片描述

👉大模型视频和PDF合集👈

观看零基础学习书籍和视频,看书籍和视频学习是最快捷也是最有效果的方式,跟着视频中老师的思路,从基础到深入,还是很容易入门的。
在这里插入图片描述
在这里插入图片描述

👉学会后的收获:👈

• 基于大模型全栈工程实现(前端、后端、产品经理、设计、数据分析等),通过这门课可获得不同能力;

• 能够利用大模型解决相关实际项目需求: 大数据时代,越来越多的企业和机构需要处理海量数据,利用大模型技术可以更好地处理这些数据,提高数据分析和决策的准确性。因此,掌握大模型应用开发技能,可以让程序员更好地应对实际项目需求;

• 基于大模型和企业数据AI应用开发,实现大模型理论、掌握GPU算力、硬件、LangChain开发框架和项目实战技能, 学会Fine-tuning垂直训练大模型(数据准备、数据蒸馏、大模型部署)一站式掌握;

• 能够完成时下热门大模型垂直领域模型训练能力,提高程序员的编码能力: 大模型应用开发需要掌握机器学习算法、深度学习框架等技术,这些技术的掌握可以提高程序员的编码能力和分析能力,让程序员更加熟练地编写高质量的代码。

👉获取方式:

😝有需要的小伙伴,可以保存图片到wx扫描二v码免费领取【保证100%免费】🆓
在这里插入图片描述

  • 6
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值