RWKV模型详解与开源代码

RWKV模型详解


RWKV (Receptance Weighted Key Value) 是一种结合了RNN(递归神经网络)和Transformer模型优点的新型架构。本文将详细介绍RWKV模型的背景、原理、与现有大模型的对比、实现代码示例以及总结和参考文献。

1. 背景与简介

随着深度学习的发展,Transformer模型在自然语言处理(NLP)任务中展现了卓越的性能。然而,Transformer的内存和计算复杂度随着序列长度的增加呈二次增长,使得其在处理长序列时非常昂贵。相比之下,RNNs在计算和内存需求方面表现出线性增长,但由于并行化能力不足,其性能无法与Transformer媲美。

为了在效率和性能之间取得平衡,RWKV模型应运而生。RWKV模型通过结合Transformer的并行计算能力和RNN的时间序列处理能力,提供了一种具有线性计算复杂度且支持并行化训练的新型架构。

2. RWKV原理与模型介绍

RWKV模型的架构由四个核心元素构成,这些元素在时间混合和通道混合块中发挥关键作用:

  • R:Receptance向量,用于接收来自过去的信息。
  • W:Weight权重向量,表示位置权重衰减向量,是模型中的可训练参数。
  • K:Key向量,类似于传统注意力机制中的Key。
  • V:Value向量,功能类似于传统注意力机制中的Value。

这些核心元素在每个时间步通过乘法进行交互,如下图所示。

2.1 架构

在这里插入图片描述

RWKV模型由堆叠的残差块(Residual Block)组成,每个块包含时间混合(Time-Mixing)和通道混合(Channel-Mixing)子块。这些子块通过引入递归结构,充分利用了过去的信息。

模型使用了一种独特的类注意力得分更新过程,其中包括一个时间依赖的Softmax操作,改进了数值稳定性并减轻了梯度消失问题。这确保了梯度沿着最相关的路径传播。此外,模型中引入的层归一化(Layer Normalization)进一步稳定了梯度,有效解决了梯度消失和爆炸问题。这些设计元素不仅增强了深度神经网络的训练动态,还促进了多层堆叠,从而在不同抽象层次上捕捉复杂模式。

2.2 Token Shift

在该架构中,所有参与计算的线性投影向量(如时间混合中的R, K, V,以及通道混合中的R’, K’)都是通过当前和前一时间步输入的线性插值生成的,促进了Token Shift的实现。

时间混合计算的向量通过以下公式生成:

r t = W r ⋅ ( μ r ⊙ x t + ( 1 − μ r ) ⊙ x t − 1 ) r_t = W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1}) rt=Wr(μrxt+(1μr)xt1)

k t = W k ⋅ ( μ k ⊙ x t + ( 1 − μ k ) ⊙ x t − 1 ) k_t = W_k \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1}) kt=Wk(μkxt+(1μk)xt1)

v t = W v ⋅ ( μ v ⊙ x t + ( 1 − μ v ) ⊙ x t − 1 ) v_t = W_v \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1}) vt=Wv(μvxt+(1μv)xt1)

通道混合的输入也通过类似的方式生成:

r t ′ = W r ′ ⋅ ( μ r ′ ⊙ x t + ( 1 − μ r ′ ) ⊙ x t − 1 ) r'_t = W'_r \cdot (\mu'_r \odot x_t + (1 - \mu'_r) \odot x_{t-1}) rt=Wr(μrxt+(1μr)xt1)

k t ′ = W k ′ ⋅ ( μ k ′ ⊙ x t + ( 1 − μ k ′ ) ⊙ x t − 1 ) k'_t = W'_k \cdot (\mu'_k \odot x_t + (1 - \mu'_k) \odot x_{t-1}) kt=Wk(μkxt+(1μk)xt1)

2.3 WKV操作符

WKV操作符的计算类似于Attention Free Transformer (AFT)中的方法。然而,与AFT不同的是,RWKV模型中的W是一个通道向量,并且受相对位置的影响。在RWKV模型中,这种递归行为由WKV向量的时间依赖更新定义,公式如下:

wkv t = ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i ⋅ v i + e u + k t ⋅ v t ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i + e u + k t \text{wkv}_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \cdot v_i + e^{u+k_t} \cdot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u+k_t}} wkvt=i=1t1e(t1i)w+ki+eu+kti=1t1e(t1i)w+kivi+eu+ktvt

为了避免W的潜在退化,我们引入了一个U向量,该向量单独关注当前的Token。

2.4 输出门控

输出门控在时间混合和通道混合块中通过Receptance的Sigmoid函数 σ ( r ) \sigma(r) σ(r)实现。WKV操作符后的输出向量 o t o_t ot的计算如下:

o t = W o ⋅ ( σ ( r t ) ⊙ wkv t ) o_t = W_o \cdot (\sigma(r_t) \odot \text{wkv}_t) ot=Wo(σ(rt)wkvt)

在通道混合块中,类似的操作如下:

o t ′ = σ ( r t ′ ) ⊙ ( W v ′ ⋅ max ⁡ ( k t ′ , 0 ) 2 ) o'_t = \sigma(r'_t) \odot (W'_v \cdot \max(k'_t, 0)^2) ot=σ(rt)(Wvmax(kt,0)2)

其中我们采用了平方ReLU激活函数。

2.5 Transformer-like训练

RWKV可以通过一种称为时间并行模式的技术进行高效并行化,这类似于Transformer。在单层中处理一批序列的时间复杂度为 O ( B T d 2 ) O(BTd^2) O(BTd2),主要由矩阵乘法 W λ W\lambda 构成,其中 λ ∈ { r , k , v , o } \lambda \in \{r, k, v, o\} λ{r,k,v,o}(假设有 B B B个序列, T T T个最大Token, d d d个通道)。与此相对,更新注意力得分 w k v t wkv_t wkvt涉及串行扫描,其复杂度为 O ( B T d ) O(BTd) O(BTd)

2.6 RNN-like推理

递归网络通常使用状态 t t t的输出作为状态 t + 1 t+1 t+1的输入。RWKV利用了这种RNN-like结构,在推理阶段可以方便地以递归形式进行解码。

2.7 额外优化

为了解决WKV计算中由任务的顺序特性引起的效率问题,我们开发了一个自定义的CUDA内核。该内核允许在训练加速器上执行单个计算内核,同时模型的其他部分(如矩阵乘法和点操作)已经具有固有的并行性和高效性。

RWKV的实现还包括小初始化嵌入和自定义初始化等优化策略,以加速和稳定训练过程,并在更深的架构中实现更好的收敛性。

3. 与现有大模型对比

在这里插入图片描述

RWKV模型在效率和性能上与现有的主流模型(如GPT、BERT等)进行了全面对比。与Transformer模型相比,RWKV在处理长序列时展现了出色的计算效率,且在同等计算资源下表现出相当的性能。

在推理阶段,RWKV的计算时间和内存需求均为线性增长,显著低于传统Transformer模型的二次增长。以下表格总结了RWKV与其他模型在推理复杂度方面的对比:

模型推理复杂度内存需求
TransformerO(T^2d)O(T^2 + Td)
ReformerO(T log Td)O(T log T + Td)
PerformerO(Td^2 log d)O(Td log d + d^2 log d)
RWKV (本模型)O(Td)O(d)

4. 开源代码(代码来自论文中的开源链接:RWKV

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            name = '[unknown weight]'
            for name, parameter in module.named_parameters(): # find the name of the weight
                if id(m.weight) == id(parameter):
                    break

            shape = m.weight.data.shape
            gain = 1.0  # positive: gain for orthogonal, negative: std for normal
            scale = 1.0 # extra scale for gain

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
                    scale = config.rwkv_emb_scale

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
                    scale = config.rwkv_emb_scale

            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if gain == 0:
                nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0, std=-gain)

class RWKV_TimeMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        assert config.n_attn % config.n_head == 0
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_head = config.n_head
        self.head_size = config.n_attn // config.n_head

        with torch.no_grad(): # initial time_w curves for better convergence
            ww = torch.ones(config.n_head, config.ctx_len)
            curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
            for h in range(config.n_head):
                if h < config.n_head - 1:
                    decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
                else:
                    decay_speed = 0.0
                ww[h] = torch.exp(curve * decay_speed)
                # print('layer', layer_id, 'head', h, 'decay_speed', round(decay_speed, 4), ww[h][:5].numpy(), '...', ww[h][-5:].numpy())
        self.time_w = nn.Parameter(ww)

        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
                
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)
        self.receptance = nn.Linear(config.n_embd, config.n_attn)

        # if config.rwkv_tiny_attn > 0:
        #     self.tiny_att = RWKV_TinyAttn(config)

        self.output = nn.Linear(config.n_attn, config.n_embd)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()
        TT = self.ctx_len
        w = F.pad(self.time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT-1:] # w is now a circulant matrix
        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
        # if hasattr(self, 'tiny_att'):
        #     tiny_att = self.tiny_att(x, self.mask)

        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)

        k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13
        k = torch.exp(k)
        sum_k = torch.cumsum(k, dim=1)

        kv = (k * v).view(B, T, self.n_head, self.head_size)

        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)

        rwkv = torch.sigmoid(r) * wkv / sum_k

        rwkv = self.output(rwkv)
        # if hasattr(self, 'tiny_att'):
        #     rwkv += tiny_att

        return rwkv * self.time_gamma[:T, :]

class RWKV_ChannelMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        
        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)
        self.receptance = nn.Linear(config.n_embd, config.n_embd)

        self.receptance.scale_init = 0
        self.weight.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()
        
        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)
        
        wkv = self.weight(F.mish(k) * v) # i find mish is a bit better than gelu

        rwkv = torch.sigmoid(r) * wkv

        return rwkv

class RWKV_TinyAttn(nn.Module): # extra tiny attention
    def __init__(self, config):
        super().__init__()
        self.d_attn = config.rwkv_tiny_attn
        self.n_head = config.rwkv_tiny_head
        self.head_size = self.d_attn // self.n_head

        self.qkv = nn.Linear(config.n_embd, self.d_attn * 3)
        self.out = nn.Linear(self.d_attn, config.n_embd)

    def forward(self, x, mask):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim = -1)

        if self.n_head > 1:
            q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)      # (B, T, C) -> (B, nh, T, hs)
            k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)      # (B, T, C) -> (B, nh, T, hs)
            v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)      # (B, T, C) -> (B, nh, T, hs)

        qk = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))     # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        qk = qk.masked_fill(mask == 0, float('-inf'))
        qk = F.softmax(qk, dim = -1)
        qkv = qk @ v                                                           # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)

        if self.n_head > 1:
            qkv = qkv.transpose(1, 2).contiguous().view(B, T, -1)              # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
       
        return self.out(qkv)

########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return self.cos_cached, self.sin_cached

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), -1)

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos[...,:q.shape[-2],:], sin[...,:q.shape[-2],:]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class MHA_rotary(nn.Module):
    def __init__(self, config, layer_id, time_shift = False):
        super().__init__()
        self.layer_id = layer_id
        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_len = config.ctx_len
        self.head_size = config.n_attn // config.n_head

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        self.query = nn.Linear(config.n_embd, config.n_attn)
        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)

        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
        
        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(config.n_attn, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()

        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)

        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)       # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)         # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)       # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        cos, sin = self.rotary_emb(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)                                     # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)
        
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))                 # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))                     # causal mask
        att = F.softmax(att, dim = -1)                                                  # softmax

        x = att @ v                                                                     # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)                               # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x)
        return x

class GeGLU(torch.nn.Module):
    def __init__(self, config, layer_id, time_shift = False):
        super().__init__()
        self.layer_id = layer_id

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        hidden_sz = 3 * config.n_ffn
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
        
        k = self.key(x)
        v = self.value(x)        
        y = self.weight(F.gelu(k) * v)
        return y

########################################################################################################
# MHA_pro: with more tricks
########################################################################################################

class MHA_pro(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_len = config.ctx_len
        self.head_size = config.n_attn // config.n_head

        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.query = nn.Linear(config.n_embd, config.n_attn)
        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)
        
        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False)  # talking heads

        self.output = nn.Linear(config.n_attn, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        TT = self.ctx_len
        w = F.pad(self.time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT-1:] # w is now a circulant matrix
        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)      # time-shift mixing
        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)       # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)         # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)       # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        cos, sin = self.rotary_emb(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)                                     # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)  
        
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))                 # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))                     # causal mask
        att = F.softmax(att, dim = -1)                                                  # softmax
        att = att * w                                                                   # time-weighting
        att = self.head_mix(att)                                                        # talking heads

        x = att @ v                                                                     # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)                               # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x) * self.time_gamma[:T, :]
        return x

########################################################################################################
# The GPT Model with our blocks
########################################################################################################

class RMSNorm(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.dd = d ** (-1. / 2)
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        x_normed = x / (norm_x * self.dd + 1e-12)
        return self.weight * x_normed

class FixedNorm(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.dd = d ** (-1. / 2)

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        x_normed = x / (norm_x * self.dd + 1e-12)
        return x_normed

########################################################################################################

class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k,v in kwargs.items():
            setattr(self, k, v)

class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if config.model_type == 'RWKV':
            # self.ln1 = FixedNorm(config.n_embd)
            # self.ln2 = FixedNorm(config.n_embd)
            self.attn = RWKV_TimeMix(config, layer_id)
            self.mlp = RWKV_ChannelMix(config, layer_id)

        elif config.model_type == 'MHA_rotary':
            self.attn = MHA_rotary(config, layer_id)
            self.mlp = GeGLU(config, layer_id)
        
        elif config.model_type == 'MHA_shift':
            self.attn = MHA_rotary(config, layer_id, time_shift=True)
            self.mlp = GeGLU(config, layer_id, time_shift=True)
        
        elif config.model_type == 'MHA_pro':
            self.attn = MHA_pro(config, layer_id)
            self.mlp = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):

        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)])

        self.ln_f = nn.LayerNorm(config.n_embd)
        self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.head_q = nn.Linear(config.n_embd, 256)
        self.head_q.scale_init = 0.01
        self.head_k = nn.Linear(config.n_embd, 256)
        self.head_k.scale_init = 0.01
        self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        if self.config.model_type == 'RWKV':
            RWKV_Init(self, config)
        else:
            self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()

        whitelist_weight_modules = (nn.Linear, )
        blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)
        return optimizer

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.tok_emb(idx)

        x = self.blocks(x)

        x = self.ln_f(x)

        q = self.head_q(x)[:,:T,:]
        k = self.head_k(x)[:,:T,:]
        c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
        c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
        c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()

        x = x * self.time_out[:, :T, :] # reduce confidence of early tokens
        x = self.head(x) + c

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))

        return x, loss

5. 总结

RWKV模型在保持Transformer模型高性能的同时,通过引入递归加权机制,有效地减少了计算复杂度和内存占用。该模型在处理长序列任务时表现优异,特别适用于资源有限的环境。随着模型的不断优化,RWKV有望成为未来高效序列处理任务中的重要工具。

6. 参考文献

1. Vaswani, A., et al. (2017). Attention is All You Need. Advances in Neural Information Processing Systems.
2. Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780.
3. Peng, B., et al. (2023). RWKV: Reinventing RNNs for the Transformer Era. arXiv preprint arXiv:2305.13048v2.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值