Pytorch框架下使用Gemma

Gemma介绍

Gemma是谷歌发布的一款开源大语言模型,并且Gemma是开源大模型的SOTA,超越了Meta的LLaMa2。综合来说,Gemma有以下几个特点:

坚实的模型基础:Gemma使用与谷歌更强大的人工智能模型Gemini相同的研究和技术。这一共同的基础确保了Gemma建立在一个强大的基础上,并具有强大的能力潜力。

轻便易用:与体型较大的Gemini不同,Gemma的设计重量轻,所需资源较少。这使得它可以被更广泛的用户访问,包括研究人员、开发人员,甚至那些计算资源有限的用户。

开放定制:Gemma模型可以进行微调,这意味着它们可以针对特定任务或应用程序进行调整和定制。这允许用户根据自己的特定需求定制模型。

“Gemma”这个名字来源于拉丁语中“宝石”的意思,反映了谷歌将这些模型视为推进人工智能研发的宝贵工具。总的来说,谷歌的Gemma有望实现强大人工智能工具的公众化,并为人工智能开发创造一个更具包容性和协作性的环境。

所以Google将Gemma放到了Github上:google/gemma_pytorch: The official PyTorch implementation of Google's Gemma models (github.com)

这里使用Kaggle这个数据科学竞赛网站提供的环境。

Kaggle申请Gemma

首先你得有个Kaggle账号,这个简单注册即可,就是有时候需要科学上网才能注册成功。

接着打开Gemma的模型页面申请:Gemma | Kaggle

申请过程就是很简单的填写同意书并接受条款和条件

最后选择Pytorch版本,点击new notebook就可以开始使用模型了。

基本使用代码讲解

Gemma模型页面中给出了Pytorch版本使用Gemma的基本代码,只需要复制进新建的notebook就可以运行。

运行示例代码会得到一个结果:
 

'What is a popular area in California? €)\n tanong:\nSneakyThrows model\nSneakyThrows user\nSneakyThrows user (more than two lines)SneakyThrows modelSneakyThrows userSneakyThrows modelSneakyThrows modelSneakyThrows userSneakyThrows modelSneakyThrows modelSneakyThrows\nSneakyThrows answerSneakyThrows userSneakyThrows modelSneakyThrows modelSneakyThrows modelSneakyThrows userSneakyThrows modelSneakyThrows userSneakyThrows modelSneakyThrows\nSneakyThrows userSneakyThrows modelSneakyThrows modelSneakyThrowsSneakyThrows modelSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrows'

就是除了第一句话,其他根本毫无逻辑。我们尝试解读一下代码,去了解为什么会生成一堆这种东西。

配置环境

# Setup the environment

# 安装必要的库
!pip install -q -U immutabledict sentencepiece

# 克隆 gemma_pytorch 代码库
!git clone https://github.com/google/gemma_pytorch.git

# 将 gemma_pytorch/gemma 目录下的所有文件移动到当前目录
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

安装的两个陌生的库——immutabledict 和 sentencepiece。

1. ImmutableDict

简介:

immutabledict 是一个 Python 库,提供不可变字典类型。与普通的字典不同,不可变字典一旦创建后就不能被修改。这使得不可变字典在多线程环境下更加安全,并且可以提高代码的性能。

主要功能:

  • 创建不可变字典
  • 访问不可变字典中的元素
  • 遍历不可变字典
  • 检查不可变字典中是否存在某个键
  • 将不可变字典转换为其他类型

使用场景:

  • 多线程环境
  • 需要提高代码性能的场景
  • 需要保证数据一致性的场景

2. SentencePiece

简介:

sentencepiece 是一个 Python 库,提供文本分词功能。它可以将文本分割成单个字符、子词或词语,以便于后续的处理。

主要功能:

  • 文本分词
  • 词汇表生成
  • 模型训练
  • 模型预测

使用场景:

  • 机器翻译
  • 文本摘要
  • 信息抽取
  • 问答系统

immutabledictsentencepiece 都是非常有用的 Python 库。immutabledict 可以提高代码的安全性 and 性能,sentencepiece 可以提高文本处理的效率。

加载模型

# 导入必要的类库
import sys  # 系统模块
sys.path.append("/kaggle/working/gemma_pytorch/")  # 添加自定义库路径
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b  # 加载配置类
from gemma.model import GemmaForCausalLM  # 加载模型类
from gemma.tokenizer import Tokenizer  # 加载分词器类
import contextlib  # 上下文管理工具
import os  # 操作系统模块
import torch  # 深度学习框架

# 设置模型类型和设备
VARIANT = "2b"  # 模型类型(2b 或 7b)
MACHINE_TYPE = "cpu"  # 运行设备(cpu 或 cuda)
weights_dir = '/kaggle/input/gemma/pytorch/2b/2'  # 模型权重所在目录

# 定义上下文管理器,设置默认张量类型
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """
    设置默认 torch dtype 为指定值,并在上下文结束后恢复为 float。
    """
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

# 加载模型配置
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")  # 设置分词器路径

# 设置设备
device = torch.device(MACHINE_TYPE)

# 使用上下文管理器设置默认张量类型,并加载模型
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)  # 创建模型对象
    ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')  # 获取模型权重路径
    model.load_weights(ckpt_path)  # 加载模型权重
    model = model.to(device).eval()  # 将模型移动到指定设备并设置为评估模式

这段代码首先导入必要的类库,然后设置模型类型、设备和权重目录。接着定义了一个上下文管理器,用于设置默认张量类型。最后,加载模型配置,创建模型对象,加载模型权重,并将其移动到指定设备并设置为评估模式。

使用模型

# 用户聊天模板
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"

# 解释:
# - `<start_of_turn>` 表示对话开始
# - `user` 表示用户
# - `{prompt}` 表示用户输入
# - `<end_of_turn>` 表示对话结束

# 模型聊天模板
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

# 解释:
# - `<start_of_turn>` 表示对话开始
# - `model` 表示模型
# - `{prompt}` 表示模型输出
# - `<end_of_turn>` 表示对话结束

# 生成对话
prompt = (
    USER_CHAT_TEMPLATE.format(prompt="中国哪个大学最好?")
    + MODEL_CHAT_TEMPLATE.format(prompt="北京大学")
    + USER_CHAT_TEMPLATE.format(prompt="北京大学在QS的排名是多少")
    + "<start_of_turn>model\n"
)

# 解释:
# - `prompt` 变量包含了完整的对话文本,包括用户输入和模型输出

# 生成模型回复
model.generate(
    prompt,
    device=device,
    output_len=100,
)

# 解释:
# - `model.generate` 函数用于生成模型回复
# - `prompt` 是模型输入
# - `device` 是模型运行设备
# - `output_len` 是模型输出的最大长度

这里我对原代码做出了改变,即最后的生成的输入使用prompt代替

USER_CHAT_TEMPLATE.format(prompt=prompt)

这是因为我觉得prompt已经遵循user和model一问一答的格式,并且最后并没有加end_of_turn这样符合模型继续生成下面文字的逻辑,而如果在格式化原来的prompt则不合逻辑。

另一个改变是我将prompt内容做了改变,测试模型对中文问题的回答怎么样。

当然结果比较惨:

"2005-2006内考满170分以上,2007-2008年170分以上,2009年1分子到160以上就真的不会考啦 prospetvi\n conquête+ de l'argent\n北京大学在QS的排名是209! RequiresApi to the moon! (๑•̀ㅂ•́)و✧\n conquête+ de l'argent\n北京大学在QS"

生成的答案更为离谱,能从生成的文字中看出答案,不过堂堂北大竟被认为QS排名209的水校。

  • 26
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值