挖来~~ 通过修改SelfAttention的执行逻辑,就可以节省大量的激活值显存开销。

通过修改SelfAttention的执行逻辑,可以节省大量的激活值显存开销。

这篇文章的消除方法来自于2021年12月10日谷歌放到arxiv上的文章self attention does not need O(n^2) memory. 该方法巧妙地使用了小学学到的加法分配率,将self attention中的固定激活值降到了O(1)的程度。[1]

Self Attention 固定激活值显存分析

 

PyTorch~固定激活值显存分析与优化_人工智能

Hugging face Transformers中,SelfAttention 内核实现

PyTorch~固定激活值显存分析与优化_Soft_02

 表格中只列举了会实测中产生激活值的操作,其中B为Batch_size,L为sequence_length,H为hidden_size,m为SelfAttention中head的数量。

PyTorch~固定激活值显存分析与优化_初始化_03

SelfAttention 固定激活值显存优化
1. Prerequisites

1.1 Softmax 计算过程

PyTorch~固定激活值显存分析与优化_人工智能_04

写成伪代码则为:

"""
inputs: Q[L][H/m], K[L][H/m], V[L][H/m]
outputs: O[L][H/m]

matrix A[L][L]=0, S[L][L]=0, O[L][H/m]=0 # 初始化为0矩阵, A,S为中间激活值矩阵
"""

# QK Matmul
for i in range(L):
    for j in range(L):
        for l in range(H/m):
            A[i][j] += Q[i][l]*Q[l][j]

# Softmax, dim=-1
for i in range(L):
    temp = 0
    for j in range(L):
        S[i][j] = math.exp(A[i][j])
        temp += S[i][j]
    S[i]/=temp

# OV Matmul
for i in range(L):
    for j in range(H/m):
        for l in range(L):
            O[i][j] += S[i][l]*Q[l][j]

return O
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
2. 显存优化

Google采用了一个非常简单的方法来节省Attention核中的大量的显存开销,具体计算过程为:

PyTorch~固定激活值显存分析与优化_CUDA_05

来避开原始的实现中所产生的A和S矩阵。

写成伪代码:

"""
Inputs: Q[L][H/m], K[L][H/m], V[L][H/m]
outputs: O[L][H/m]

matrix O[L][H/m]=0 # 初始化为0矩阵
"""

for i in range(L): # O row, Q row
        sum_s = 0
        for j in range(L): # O column, K^T column, V row
            a_ij = 0
            for k in range(H/m): # Q column, K^T row
                a_ij += Q[i][k]*K[k][j] # Q_i K_j matmul
            a_ij = a_ij / math.sqrt(H) # scale
            s_ij_prime = math.exp(a_ij) # softmax numerator
            sum_s_i += s_prime_ij # softmax denominator of i-th row
            for oj in range(H/m): # broacast along V column axis
                if random.uniform(0,1) > 0.1: # dropout
                    O[i][oj] += s_ij_prime * V[j][oj] # attention weight, V matmul
        O[i][:] = O[i][:] / sum_s # attention weight, V matmul 
return O
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.

一个可行的PyTorch api实现,但是效率很低很低,不可能用的。效率想要高估计还是需要用CUDA去写个算子...按照文章的说法,实现的好的话,推断的时候是可以比原始方法要快的,但是就训练而言,这里在后向过程中肯定需要进行丢失信息的重计算,论文里可以预见的会被原始方法慢两倍。   

key_layer = key_layer.transpose(-1, -2)

outputs = torch.zeros([1, self.num_attention_heads, 512, 64])
for i in range(512):  # sequence length
    Qi = torch.narrow(query_layer, 2, i, 1)  # (1, 16, 1, 64)
    sum_s = torch.zeros([1, self.num_attention_heads, 1, 1])
    outputs_i = torch.narrow(outputs, 2, i, 1)  # (1, 16, 1, 64)

    for j in range(512): 
        Kj = torch.narrow(key_layer, 3, j, 1)  # (1, 16, 64, 1)
        A_ij = torch.matmul(Qi, Kj) / math.sqrt(self.attention_head_size)  # (1, 16, 1, 1)
        s_ij_prime = torch.exp(A_ij)
        sum_s.add(s_ij_prime)
        V_j = torch.narrow(value_layer, 2, j, 1)  # (1, 16, 1, 64) jth_row
        if random.uniform(0,1) > 0.1:
            outputs_i.add(s_ij_prime.mul(V_j))  # (1, 16, 1, 64)
     outputs_i.div(sum_s)

outputs = outputs.permute(0, 2, 1, 3).contiguous()
outputs_shape = outputs.size()[
                        :-2] + (self.all_head_size,)
outputs = outputs.view(*outputs_shape)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.

PyTorch~固定激活值显存分析与优化_CUDA_06

 不想写CUDA又想要提升性能的话,可以考虑narrow的时候多取几行或者几列,跟GPU的核数对应上应该比较合适(文章里是4096,也忒大了),然后换成einsum的张量乘法实现可调整遍历窗口大小的优化方法。

总结
  • 这个方法跟原始方法在逻辑上是等价的,而且计算复杂度也是一致的。

PyTorch~固定激活值显存分析与优化_Soft_07

  • 使用的时候需要注意在计算指数的时候可能会存在的溢出问题(这个原始实现里也有),因此文章里面的实现在做指数运算前减去了最大的A_ij值。
  • 收敛性相同,且在训练小Transformer时有4个百分点的速度提升。
  • 需要在Backward的时候重计算丢失掉的信息,这里可能会影响到dropout,所以dropout的结果我猜肯定在前向的时候是不能被丢弃的。
  • 推理系统的福音,可以调整并降低中间产生的激活值峰值,同时保证一定的推理速度。