本地运行Gemma的pytorch集成

本文介绍了Google发布的Gemma大模型,特别是2B-it版本的部署方法,包括安装Python环境、依赖库、连接Kaggle下载模型,以及如何在本地使用GPU进行会话生成。
摘要由CSDN通过智能技术生成

Gemma是Google在2024年2月21日发布的一款轻量的开源大模型,采用了和Google Gemini模型一样的技术。有猜测Google在毫无预告的情况下急忙发布Gemma是对Meta的Llama3的截胡,但不管怎么说作为名厂名牌的大模型,自然要上手尝试尝试。

这次发布的Gemma有2B参数和7B参数两个版本,两个版本又分别提供了预训练(Pretrained)和指令调试(Instruction tuned)两个版本。预训练版本做了基础训练,而指令调试版本做了根据人类语言交互的特定训练调整,所以如果直接拿来做会话使用可以下载it版本。2B和7B在于参数量的多少,7B需要更多的资源去运行。

好了,前面啰嗦了一堆背景,为了引出这里介绍2b-it版本地部署的原因——耗资源少且可以本地使用会话。

准备环境

  • 安装python venv,命名gemma-torch
conda env create -n "gemma-torch"
  • 激活虚拟环境
conda activate gemma-torch
  • 安装依赖的库
pip install torch immutabledict sentencepiece numpy packaging

 后面两个库不是官方文档里要求的,但是根据我执行报错,需要安装。另外上面命令也取消了-q -U简单粗暴也方便观察。

为了后续用代码连接kaggle下载模型,还需要安装kagglehub包:

pip install kagglehub

连接kaggle

这一步的目的是从kaggle上面下载模型。

  • 首先获取kaggle的访问权限

登录kaggle,在设置页面(https://www.kaggle.com/settings)的API一节点击“Create New Token”,会触发下载kaggle.json。

  • 配置环境

将kaggle.json文件拷贝到~/.kaggle/目录下。并在~/.bash_profile中设置环境变量KAGGLE_CONFIG_DIR为~/.kaggle。

这样就可以通过下面代码访问(后面的代码写到一块,不需要此处执行)。

import kagglehub

kagglehub.login()

运行代码

经过前面的配置后,可以代码本地运行2b-it模型了。不过加载模型还需要gemma_pytorch包。

从github仓库clone到本地:

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git

将下载好的gemma_pytorch文件夹放到下面脚本文件同一级目录下 ,并在~/.bash_profile中设置PYTHONPATH环境变量包含该文件夹路径。

最后运行脚本(gemma_torch.py):

# Choose variant and machine type
import kagglehub
import os
import sys
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM
import torch


VARIANT = '2b-it'
#如果是cpu运行,将下面cuda改成cpu,不过巨慢
MACHINE_TYPE = 'cuda'

# Load model weights
# 模型下载到了~/.cache目录下
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)

# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=60,
)

一点后话

能用GPU还是上GPU吧,我本地用的CPU笔记本跑的巨慢。

可以在线使用colab,具体步骤参考这个帖子(昨天Google发布了最新的开源模型Gemma,今天我来体验一下_gemma_lm.generate-CSDN博客)。

不过我在使用过程中发现T4经常在预测执行时报OOM,导致无法产出结果。

参考资料:

pytorch中使用Gemma: https://ai.google.dev/gemma/docs/pytorch_gemma

官方文档地址:https://ai.google.dev/gemma/docs 

  • 21
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值