A Linear Transformer Based on Attention Pooling
Abstract
Currently, transformer-based neural networks dominate the field of artificial intelligence, with widespread applications in natural language processing, computer vision, image generation, and multimodal domains. In particular, since OpenAI (reference) released ChatGPT in 2022, GPT has become the undisputed leader in the language domain. However, from a first-principles perspective, there are significant differences between GPT and human thinking. These differences are mainly reflected in two aspects: first, GPT must store all historical information, whereas humans are able to forget; second, the complexity of GPT is $O(N^2)$, whereas human reading complexity is $O(N)$. Given these differences, applying GPT to lifelong learning still faces several challenges. This paper proposes a linear transformer model based on attention pooling, which achieves linear complexity and a forgetting capability while retaining the core attention mechanism of GPT.
Introduction
Due to its ability to capture long-range dependencies between tokens, the GPT architecture has achieved great success in natural language processing in recent years. Among the mainstream solutions adopted by major companies, it is now rare to see RNNs or LSTMs. These models have significant drawbacks in training and deployment, mainly their inability to parallelize training and their catastrophic forgetting issues. Although the GPT architecture has been so successful, it still has several limitations. First, its principles differ fundamentally from how human intelligence works: human intelligence includes a forgetting mechanism, while GPT does not. Although the lack of a forgetting mechanism may make GPT more powerful, it poses a fatal problem for lifelong learning, because the stored historical information keeps growing and the computational cost of retrieving it grows with it. Both factors increase inference time as the context lengthens.
To address these issues, compressing historical information is an unavoidable step. This paper proposes an attention pooling method that compresses historical information so that the current token can query it. This method supports both parallel training and causal recursive inference.
In summary, the contributions of this paper are as follows:
- Proposes a linear transformer named WQKV, which uses attention pooling to compress historical information.
- Presents parallel training and recursive causal inference methods for WQKV.
- Conducts ablation experiments on WQKV.
Related Work
There has been significant work on reducing the complexity of transformers from $O(N^2)$ to $O(N)$.
RWKV is inspired by Apple's Attention Free Transformer. The RWKV architecture is carefully simplified and optimized so that it can be converted into an RNN. In addition, techniques such as TokenShift and SmallInitEmb are used to make RWKV perform comparably to GPT.
In “Were RNNs All We Needed?”, minLSTM and minGRU are obtained by removing the nonlinear dependencies that prevent parallel computation. The resulting RNNs support recursive inference and can be trained in parallel using the parallel scan algorithm.
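To make this concrete, the sketch below shows the kind of gated linear recurrence such models rely on, assuming an update of the form $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$ in which $z_t$ and $\tilde{h}_t$ do not depend on $h_{t-1}$. The function name and the closed-form cumulative-product formulation are our own illustrative choices, not the exact implementation of those papers.

```python
import torch

def gated_linear_recurrence(z, h_tilde, h0):
    """Compute h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde_t for all t at once.

    z, h_tilde: [L, D] tensors whose entries do not depend on previous h,
    which is what makes the recurrence solvable without a sequential loop.
    Illustrative only: a real implementation would use a log-space
    associative scan for numerical stability instead of dividing by cumprod.
    """
    a = 1.0 - z                       # per-step decay, [L, D]
    A = torch.cumprod(a, dim=0)       # A_t = prod_{i<=t} a_i
    b = z * h_tilde                   # per-step input contribution
    S = torch.cumsum(b / A, dim=0)    # S_t = sum_{i<=t} b_i / A_i
    return A * h0 + A * S             # h_t = A_t * h0 + A_t * S_t
```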
However, these methods have some drawbacks. To support a large context (state), they often require a very large number of parameters, a requirement our method avoids. Moreover, these methods lack a clear connection to the attention mechanism of transformers, and their effectiveness has not been widely validated. In contrast, our method fully inherits the attention mechanism of transformers and only adds an attention pooling layer. Therefore, the scalability of our method is inherited from the transformer itself.
Attention pooling.
Method
A straightforward approach to reducing the complexity of GPT from $O(N^2)$ to $O(N)$ is to compress the queried historical information into a fixed-length representation, so that the current token attends to a fixed-size context. Clearly, this process resembles pooling. We therefore design a softmax-weighted pooling method that compresses historical information into a fixed number of vectors, after which the original GPT attention operation is applied. Below, we describe how this is achieved.
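As a quick sanity check on the complexity claim (using the notation introduced below, where $L$ is the sequence length, $C$ the channel dimension, and $D$ the fixed number of pooled vectors): standard causal attention lets every token attend over up to $L$ keys, costing $O(L^2 C)$, whereas attending over a pooled context of fixed size $D$ costs $O(L D C)$, which is linear in $L$ for fixed $D$.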
For simplicity, this paper assumes a batch size of 1 and does not use multi-head attention.
The attention pooling used in this paper is illustrated in the following figure.
Consistent with GPT, for the current sequence $X$ with shape $[L, C]$, we obtain $Q, K, V$ with shape $[L, C]$ using the matrices $W_q, W_k, W_v$. The specific calculations are as follows:
$$
K = W_k X \\
V = W_v X
$$
In addition, we use an extra weight matrix $W_w$ to obtain $W = W_w X$ with shape $[L, D]$. This matrix is used to compress the historical information. The attention pooling results are:
$$
K' = \text{softmax}(W^T) K \\
V' = \text{softmax}(W^T) V
$$
Thus, the shape of $K, V$ is compressed from $[L, C]$ to $[D, C]$.
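To make the computation above concrete, the following PyTorch sketch implements the pooling step as described so far. The class and variable names are our own assumptions, the projections are written as `nn.Linear` layers, and no causal masking is applied, so this version pools over the full sequence rather than each token's causal history.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    """Softmax-weighted pooling of K and V into D slots (sketch, our naming).

    Input  x: [L, C]  ->  pooled K', V': [D, C]
    """
    def __init__(self, dim: int, num_slots: int):
        super().__init__()
        self.w_k = nn.Linear(dim, dim, bias=False)        # plays the role of W_k
        self.w_v = nn.Linear(dim, dim, bias=False)        # plays the role of W_v
        self.w_w = nn.Linear(dim, num_slots, bias=False)  # plays the role of W_w

    def forward(self, x: torch.Tensor):
        k = self.w_k(x)                               # [L, C]
        v = self.w_v(x)                               # [L, C]
        w = self.w_w(x)                               # [L, D]
        # Softmax over the sequence axis: each of the D slots becomes a
        # convex combination of the L positions.
        attn = F.softmax(w.transpose(0, 1), dim=-1)   # [D, L]
        k_pooled = attn @ k                           # [D, C] = softmax(W^T) K
        v_pooled = attn @ v                           # [D, C] = softmax(W^T) V
        return k_pooled, v_pooled

# Example: compress a length-128 sequence with C=64 channels into D=16 slots.
x = torch.randn(128, 64)
pool = AttentionPooling(dim=64, num_slots=16)
k_pooled, v_pooled = pool(x)
print(k_pooled.shape, v_pooled.shape)  # torch.Size([16, 64]) torch.Size([16, 64])
```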