Inference with Gemma using JAX and Flax
Overview
Gemma is a family of lightweight, state-of-the-art open large language models, based on the Google DeepMind Gemini research and technology. This tutorial demonstrates how to perform basic sampling/inference with the Gemma 2B Instruct model using Google DeepMind's gemma library that was written with JAX (a high-performance numerical computing library), Flax (the JAX-based neural network library), Orbax (a JAX-based library for training utilities like checkpointing), and SentencePiece (a tokenizer/detokenizer library). Although Flax is not used directly in this notebook, Flax was used to create Gemma.
This notebook can run on Google Colab with a free T4 GPU (go to Edit > Notebook settings and, under Hardware accelerator, select T4 GPU).


Setup
1. Set up Kaggle access for Gemma
To complete this tutorial, you first need to follow the setup instructions at Gemma setup, which show you how to do the following:
- Get access to Gemma on kaggle.com.
- Select a Colab runtime with sufficient resources to run the Gemma model.
- Generate and configure a Kaggle username and API key.
After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.
1. Get access to Gemma on kaggle.com.


After you click Accept, the site switches to another page, as shown below:

2. Select a Colab runtime with sufficient resources to run the Gemma model.


3. Set the hardware accelerator to GPU T4 x 2.

4. Generate and configure a Kaggle username and API key


5. How to obtain a Kaggle account and API key:
- Log in to Kaggle: If you don't have a Kaggle account yet, register one first, then log in to your account.
- Go to your account settings: After logging in, click the user icon in the top-right corner and select the "Account" option. This takes you to your account settings page.
- Create an API key: In the left-hand menu of the account settings page, select the "API" option. You will see a button labeled "Create New API Token"; click it.
- Download the API key: After you click "Create New API Token", Kaggle generates a JSON file containing your username and API key. Store this file somewhere safe, because it is the credential you need to interact with the Kaggle API.
- Configure the Kaggle API: Copy the username and API key from the downloaded JSON file onto your machine. You can save them to ~/.kaggle/kaggle.json, the default location used by the Kaggle command-line tool (kaggle), for example with the sketch after this list. Make sure you never expose your API key.
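If you prefer to set up the credentials file yourself rather than relying on notebook secrets, a minimal sketch along the following lines should work; the username and key below are placeholders, not real values:
import json, os, stat
# Placeholder credentials; replace them with the values from the kaggle.json you downloaded.
creds = {"username": "your_kaggle_username", "key": "your_api_key"}
kaggle_dir = os.path.expanduser("~/.kaggle")
os.makedirs(kaggle_dir, exist_ok=True)
cred_path = os.path.join(kaggle_dir, "kaggle.json")
with open(cred_path, "w") as f:
    json.dump(creds, f)
# The Kaggle CLI expects the credentials file to be readable only by its owner.
os.chmod(cred_path, stat.S_IRUSR | stat.S_IWUSR)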
Set environment variables
Step 1: The default Kaggle notebook environment
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

Step 2: Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY. When prompted with the "Grant access?" messages, agree to provide secret access.
import os
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
# Export the secrets as environment variables so that kagglehub can authenticate with Kaggle later on.
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = user_secrets.get_secret("KAGGLE_USERNAME")
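If you are running in Google Colab rather than in a Kaggle notebook, the equivalent mechanism is the Colab secrets manager. A minimal sketch, assuming you have stored KAGGLE_USERNAME and KAGGLE_KEY as Colab secrets:
import os
from google.colab import userdata  # available only inside Colab
# Reading a secret for the first time triggers the "Grant access?" prompt.
os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")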

Step 3: This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click Edit > Notebook settings > select T4 GPU > Save.
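Before installing anything, you can confirm that JAX sees the accelerator; jax comes preinstalled on Colab and Kaggle images:
import jax
# On a T4 runtime this should list a CUDA device and report the 'gpu' backend.
print(jax.devices())
print(jax.default_backend())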
Next, you need to install the Google DeepMind gemma library from github.com/google-deepmind/gemma. If you get an error about "pip's dependency resolver", you can usually ignore it.
Note: By installing gemma, you will also install flax, core jax, optax (the JAX-based gradient processing and optimization library), orbax, and sentencepiece.
!pip install -q git+https://github.com/google-deepmind/gemma.git
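As an optional sanity check that the install pulled in the whole stack, the following imports should succeed without errors:
# These imports should all resolve after the pip install above.
import gemma
import jax, flax, orbax.checkpoint, sentencepiece
print('jax', jax.__version__, '| flax', flax.__version__)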

Load and prepare the Gemma model
1. Load the Gemma model with kagglehub.model_download, which takes three arguments:
- handle: The model handle from Kaggle
- path: (Optional string) The local path
- force_download: (Optional boolean) Forces a re-download of the model (see the example after the download output below)
Note: Be mindful that the gemma-2b-it model is around 3.7GB in size.
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}

import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')

print('GEMMA_PATH:', GEMMA_PATH)

GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
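If you ever need a fresh copy of the weights, for example after an interrupted download, the optional arguments described above can be passed explicitly. A sketch that reuses the same handle:
# Force a re-download of the model even if a cached copy already exists.
GEMMA_PATH = kagglehub.model_download(
    f'google/gemma/flax/{GEMMA_VARIANT}',
    force_download=True,
)
print('GEMMA_PATH:', GEMMA_PATH)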
2. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer directory will be in the main directory where you downloaded the model, while the model weights will be in a sub-directory. For example:
- The tokenizer.model file will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2.
- The model checkpoint will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
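To confirm the layout described above, list the contents of the download directory; the exact entries may vary with the model version:
# Expect the tokenizer file and the checkpoint sub-directory, e.g. ['tokenizer.model', '2b-it'].
print(os.listdir(GEMMA_PATH))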

Perform sampling/inference
1. Load and format the Gemma model checkpoint with the gemma.params.load_and_format_params method:
from gemma import params as params_lib
params = params_lib.load_and_format_params(CKPT_PATH)
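As an optional check that the checkpoint loaded correctly, you can peek at the parameter tree; the top-level keys and the exact count depend on the model and library version:
import jax
# The formatted checkpoint is a nested dict of arrays; count every leaf in the tree.
print('top-level keys:', list(params.keys()))
num_params = sum(int(leaf.size) for leaf in jax.tree_util.tree_leaves(params))
print(f'total parameters: {num_params:,}')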

2. Load the Gemma tokenizer, constructed using sentencepiece.SentencePieceProcessor:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
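A quick, optional round trip through the tokenizer shows what the model will actually see; the exact token ids below are illustrative:
# Encode a string to ids, decode it back, and report the vocabulary size.
ids = vocab.EncodeAsIds('What is the meaning of life?')
print('token ids:', ids)
print('decoded:', vocab.DecodeIds(ids))
print('vocab size:', vocab.GetPieceSize())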

3. To automatically load the correct configuration from the Gemma model checkpoint, use gemma.transformer.TransformerConfig. The cache_size argument is the number of time steps in the Gemma Transformer cache. Afterwards, instantiate the Gemma model as transformer with gemma.transformer.Transformer (which inherits from flax.linen.Module).
Note: The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current Gemma release.
from gemma import transformer as transformer_lib
transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)
transformer = transformer_lib.Transformer(transformer_config)

4. Create a sampler with gemma.sampler.Sampler on top of the Gemma model checkpoint/weights and the tokenizer:
from gemma import sampler as sampler_lib
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)

5. Write a prompt in the prompt list and perform inference. You can tweak total_generation_steps (the number of steps performed when generating a response; this example uses 100 to preserve host memory).
Note: If you run out of memory, click on Runtime > Disconnect and delete runtime, and then Runtime > Run all.
prompt = [
    "\n# What is the meaning of life?",
]
reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )
for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")

6. (Optional) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the sampler again in step 4 and customize and run the prompt in step 5.
del sampler
This article shows how to use Google DeepMind's gemma library, built on JAX, Flax, Orbax, and SentencePiece, to perform sampling and inference with the Gemma 2B Instruct model. The tutorial covers setting up Kaggle access, obtaining GPU resources, and the hands-on steps for running everything in Google Colab.
