Mamba-minimal: A Minimal Implementation of Mamba (Part 2)


This is a simple, minimal implementation of Mamba. Compared with the original implementation at state-spaces/mamba (github.com), the parameters are not carefully initialized, in favor of readability, and the original implements the parallel scan in CUDA, so it is faster.
This part covers the rest of the code: building a full sequence model from MambaBlock plus other components such as residual connections and normalization.

For an introduction to MambaBlock, see Mamba-minimal: A Minimal Implementation of Mamba (Part 1) (CSDN blog).

Link

From johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

Import the required packages

from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union  # needed for the dt_rank annotation below
from einops import rearrange, repeat, einsum

class ModelArgs

Model hyperparameter settings.

Parameter overview:
d_model: model dimension, matching the number of input channels
n_layer: number of residual blocks
d_state: latent state dimension
expand: expansion factor; d_inner = expand * d_model
dt_rank: rank of the Δ (delta) projection
d_conv: kernel size of the 1D convolution
vocab_size: vocabulary size
pad_vocab_size_multiple: vocab_size is padded up to a multiple of this value
conv_bias: bias option for the 1D convolution
bias: bias option for the linear projections in MambaBlock (the lm_head below hardcodes bias=False)
@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4 
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False
   
    
    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            # Default rank for the Δ projection: d_model / 16, rounded up.
            self.dt_rank = math.ceil(self.d_model / 16)

        # Round vocab_size up to a multiple of pad_vocab_size_multiple.
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)
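As a quick sketch (the values here are illustrative, not from the post), __post_init__ fills in the derived fields automatically:

args = ModelArgs(d_model=512, n_layer=6, vocab_size=50277)
print(args.d_inner)     # 1024  = expand * d_model
print(args.dt_rank)     # 32    = ceil(d_model / 16)
print(args.vocab_size)  # 50280 = 50277 rounded up to a multiple of 8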

class Mamba

A complete Mamba sequence model, consisting of several wrapped MambaBlocks.

For details on nn.Embedding, see the CSDN post "pytorch nn.Embedding详解".

The lm_head layer is the output layer for next-token prediction: it maps the model's output to a distribution over the vocabulary so the model can predict the next token, and its weights are shared with the embedding (weight tying).

The input is a sequence x of shape (batch_size, length), abbreviated (b, l); the output has shape (b, l, vocab_size), giving a score (logit) for each token in the vocabulary.

Shape transformations per component:
embedding: (b, l) -> (b, l, d_model)
layers: (b, l, d_model) -> (b, l, d_model)
norm_f: (b, l, d_model) -> (b, l, d_model)
lm_head: (b, l, d_model) -> (b, l, vocab_size)

def __init__

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args
        
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper

def forward

    def forward(self, input_ids):
        """
        Args:
            input_ids: token ids, shape (b, l)
        Returns:
            logits: shape (b, l, vocab_size)
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)
        return logits
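As a minimal sketch of the shape flow (illustrative sizes; it assumes MambaBlock from Part 1 and the ResidualBlock/RMSNorm classes below are already defined):

args = ModelArgs(d_model=64, n_layer=2, vocab_size=1000)
model = Mamba(args)
# Weight tying: lm_head and embedding share the same Parameter object.
assert model.lm_head.weight is model.embedding.weight

input_ids = torch.randint(0, args.vocab_size, (2, 10))  # (b, l)
logits = model(input_ids)
print(logits.shape)  # torch.Size([2, 10, 1000]) -> (b, l, vocab_size)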

class ResidualBlock

A residual block that wraps a single MambaBlock in the pre-norm form: the input is normalized, passed through the MambaBlock, and then added back. (The official repo chains blocks as [Add -> Norm -> Mamba] so the Add->Norm can be fused for speed; the [Norm -> Mamba -> Add] form used here is numerically equivalent.)

For an introduction to MambaBlock, see Mamba-minimal: A Minimal Implementation of Mamba (Part 1) (CSDN blog).

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        # Pre-norm residual: normalize, run the MambaBlock, add the input back.
        output = self.mixer(self.norm(x)) + x

        return output
            

class RMSNorm

The normalization used throughout the model.

For background, see the CSDN post reading the RMSNorm paper,

and "RMSNorm in LLMs" (zhihu.com).
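The forward pass below implements, for each feature vector x of dimension d = d_model with learned gain w:

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot w$$

Unlike LayerNorm, no mean is subtracted and there is no bias term.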

class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        # Scale by the reciprocal RMS over the feature dimension, then apply the learned gain.
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output
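A quick numeric check (toy values): with the default unit gain, each feature is simply divided by the RMS of the vector.

norm = RMSNorm(d_model=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
# mean(x**2) = (1 + 4 + 9 + 16) / 4 = 7.5, rms = sqrt(7.5) ≈ 2.7386
print(norm(x))  # ≈ [[0.3651, 0.7303, 1.0954, 1.4606]]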

Text generation demo

From demo.ipynb

A colab_demo is also available.

Load the model

from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'

model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

Generate text

At each step, sample from the top-k highest-probability tokens.

import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    model.eval()
    
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    
    for token_n in range(n_tokens_to_gen):
        # No state caching here: the full sequence is re-run at every step.
        with torch.no_grad():
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]
        
        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape
        
        if top_k is not None:
            # Zero out everything below the k-th largest probability,
            # then renormalize so the remaining probabilities sum to 1.
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(dim=1, keepdim=True)
        
        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]
        
        # Append the chosen token and feed the extended sequence back in.
        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]
    
    return output_completions
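To make the top-k filtering step concrete, here is a toy sketch with hypothetical values: probabilities below the k-th largest are zeroed out, and the survivors are renormalized before sampling.

probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])
(values, indices) = torch.topk(probs, k=2)      # values = [[0.5, 0.3]]
probs[probs < values[:, -1, None]] = 0          # [[0.5, 0.3, 0.0, 0.0]]
probs = probs / probs.sum(dim=1, keepdim=True)  # [[0.625, 0.375, 0.0, 0.0]]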


print(generate(model, tokenizer, 'Mamba is the'))