Inference with Gemma using JAX and Flax: A Hands-On Project

This article describes how to perform sampling/inference with the Gemma 2B Instruct model using Google DeepMind's gemma library, built on JAX, Flax, Orbax, and SentencePiece. The tutorial covers setting up Kaggle access, obtaining GPU resources, and the hands-on steps in Google Colab.

Link Address: Inference with Gemma using JAX and Flax

Overview

Gemma is a family of lightweight, state-of-the-art open large language models, based on the Google DeepMind Gemini research and technology. This tutorial demonstrates how to perform basic sampling/inference with the Gemma 2B Instruct model using Google DeepMind's gemma library that was written with JAX (a high-performance numerical computing library), Flax (the JAX-based neural network library), Orbax (a JAX-based library for training utilities like checkpointing), and SentencePiece (a tokenizer/detokenizer library). Although Flax is not used directly in this notebook, Flax was used to create Gemma.

This notebook can run on Google Colab with a free T4 GPU (go to Edit > Notebook settings and, under Hardware accelerator, select T4 GPU).


Setup

1. Set up Kaggle access for Gemma

To complete this tutorial, you first need to follow the setup instructions at Gemma setup, which show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma model.
  • Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

 

1. Get access to Gemma on kaggle.com.

After you click Accept, you will be redirected to another page, shown below:

Gemma | Kaggle

2. Select a Colab runtime with sufficient resources to run the Gemma model.

3. Set the hardware accelerator to GPU T4 x 2.

4. Generate and configure a Kaggle username and API key.

5. How to obtain a Kaggle account and API key:

  1. Sign in to Kaggle: if you don't have a Kaggle account yet, register one first, then sign in.
  2. Go to account settings: after signing in, click the user icon in the top-right corner and choose "Account" to open your account settings page.
  3. Create an API key: in the "API" section of the account settings page, click the button labeled "Create New API Token".
  4. Download the API key: after you click "Create New API Token", Kaggle generates a JSON file containing your username and API key. Keep this file somewhere safe; it is the credential you need to interact with the Kaggle API.
  5. Configure the Kaggle API: copy the username and API key from the downloaded JSON file onto your machine, typically into ~/.kaggle/kaggle.json, the default location used by the kaggle command-line tool (a minimal sketch follows this list). Make sure you never expose your API key.
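
The snippet below is a minimal sketch of step 5, not from the original tutorial; YOUR_KAGGLE_USERNAME and YOUR_KAGGLE_KEY are placeholders for the values in your downloaded kaggle.json.

import json, os, stat

# Placeholders: substitute the values from your downloaded kaggle.json.
creds = {'username': 'YOUR_KAGGLE_USERNAME', 'key': 'YOUR_KAGGLE_KEY'}

kaggle_dir = os.path.expanduser('~/.kaggle')
os.makedirs(kaggle_dir, exist_ok=True)
cred_path = os.path.join(kaggle_dir, 'kaggle.json')
with open(cred_path, 'w') as f:
    json.dump(creds, f)
os.chmod(cred_path, stat.S_IRUSR | stat.S_IWUSR)  # 0600: readable only by you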

Set environment variables

Step 1: Start from the default Kaggle notebook environment. A new Kaggle notebook begins with this starter cell:

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

Step 2: Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY. When prompted with the "Grant access?" message, agree to provide secret access.

import os
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
# Export the secrets as environment variables so Kaggle tooling can authenticate.
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = user_secrets.get_secret("KAGGLE_USERNAME")

Step 3: This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click Edit > Notebook settings > select T4 GPU > Save.

Next, you need to install the Google DeepMind gemma library from github.com/google-deepmind/gemma. If you get an error about "pip's dependency resolver", you can usually ignore it.

Note: By installing gemma, you will also install flax, core jax, optax (the JAX-based gradient processing and optimization library), orbax, and sentencepiece.

pip install -q git+https://github.com/google-deepmind/gemma.git
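
After installation, a quick sanity check (my addition, not part of the original tutorial) confirms that JAX can see the accelerator:

import jax

# On a T4 runtime this should print something like [cuda(id=0)].
print(jax.devices())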

Load and prepare the Gemma model

  1. Load the Gemma model with kagglehub.model_download, which takes three arguments:
  • handle: The model handle from Kaggle
  • path: (Optional string) The local path
  • force_download: (Optional boolean) Forces a re-download of the model

Note: Be mindful that the gemma-2b-it model is around 3.7 GB in size.

GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}

import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')

print('GEMMA_PATH:', GEMMA_PATH)

GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2

2. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer directory will be in the main directory where you downloaded the model, while the model weights will be in a sub-directory. For example:

  • The tokenizer.model file will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • The model checkpoint will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
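
As a sanity check (my addition, not in the original tutorial), confirm both paths exist before loading anything:

# Fail early if the download layout differs from what the tutorial expects.
assert os.path.isdir(CKPT_PATH), f'missing checkpoint dir: {CKPT_PATH}'
assert os.path.isfile(TOKENIZER_PATH), f'missing tokenizer: {TOKENIZER_PATH}'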

 

Perform sampling/inference

  1. Load and format the Gemma model checkpoint with the gemma.params.load_and_format_params method:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
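
The loaded checkpoint is a nested dictionary of arrays; its top level contains the 'transformer' entry that the sampler uses later in this tutorial. A quick look (my addition):

# The checkpoint is a nested dict; 'transformer' holds the model weights.
print(params.keys())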

 

   2. Load the Gemma tokenizer, constructed using sentencepiece.SentencePieceProcessor:

import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
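
To confirm the tokenizer loaded correctly, here is a small round-trip check (my addition) using the standard SentencePiece API:

# Encode a string to token ids, then decode it back.
ids = vocab.EncodeAsIds('Hello, Gemma!')
print(ids)
print(vocab.DecodeIds(ids))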

  3. To automatically load the correct configuration from the Gemma model checkpoint, use gemma.transformer.TransformerConfig. The cache_size argument is the number of time steps in the Gemma Transformer cache. Afterwards, instantiate the Gemma model as transformer with gemma.transformer.Transformer (which inherits from flax.linen.Module).

Note: The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current Gemma release.

from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
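
To see the note above in action, you can compare the model's input-embedding count with the tokenizer's vocabulary size. This check is my addition; num_embed is the field name in the gemma library's TransformerConfig, so treat it as an assumption if your version differs.

# The embedding table is larger than the tokenizer vocabulary (unused tokens).
print('num_embed:', transformer_config.num_embed)
print('tokenizer vocab size:', vocab.GetPieceSize())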

  4. Create a sampler with gemma.sampler.Sampler on top of the Gemma model checkpoint/weights and the tokenizer:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)

  5. Write a prompt and perform inference. You can tweak total_generation_steps (the number of steps performed when generating a response; this example uses 100 to preserve host memory).

Note: If you run out of memory, click on Runtime > Disconnect and delete runtime, and then Runtime > Run all.

prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")

 

  6. (Optional) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the sampler again in step 4 and customize and run the prompt in step 5.
del sampler
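
Deleting the sampler only drops the Python reference; to encourage the memory to be reclaimed immediately, you can also invoke the garbage collector (my addition, not part of the original tutorial):

import gc

gc.collect()  # ask Python to release the now-unreferenced arrays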
