LLMs / Gemma: a walkthrough of gsm8k_eval.ipynb, evaluating the Gemma model on the GSM8K dataset with question-answer prompting

This article shows how to evaluate the Gemma model on the GSM8K dataset: installing the library and downloading the model, loading the dataset, defining helper functions to build the prompting input, and running a loop to measure performance on the test set. It also discusses why GSM8K's characteristics make it a useful benchmark for small models.

Contents

Core steps

Evaluating GSM8K with Gemma

1. Installation and download

Download the model checkpoint

2. Load the model and dataset and run the evaluation

Load the GSM8K dataset

# @title Testing library

# @title GSM8K Prompts

Load the language model and tokenizer

Build the sampler

Evaluate the model on the GSM8K test set in a loop


Core steps

>> Download the Gemma model weights from Kaggle; they serve as the base model for this evaluation.

>> Load and preprocess the GSM8K dataset, separating the train and test splits.

>> Define a few helper functions, such as finding the numbers in a string and extracting the final answer.

>> Build the prompting input by appending each GSM8K question to the few-shot prompt.

>> Use the sampler to generate a response for every test question, then extract the answer and check it.

>> Record the model's prediction on each test example and compute the accuracy at the end.

Evaluating GSM8K with Gemma

The GSM8K dataset poses a good evaluation challenge for small models, for several reasons:

>> Conceptual simplicity: although the problems in GSM8K require multi-step reasoning, they mainly involve elementary mathematical concepts and basic arithmetic operations. This keeps the dataset accessible to small models that may not have the capacity for complex mathematical reasoning.

>> Linguistic diversity: GSM8K emphasizes linguistic diversity, ensuring that the problems are not just variations of a single template. This forces models to generalize their understanding of language and mathematical concepts instead of relying on superficial pattern matching.

>> Moderate difficulty: the problems are challenging enough to test the limits of small models without being completely intractable, which allows meaningful evaluation and comparison of different models and methods within a reasonable difficulty range.

>> Natural-language solutions: GSM8K provides solutions in natural language, encouraging models to develop verbal reasoning skills and to produce human-interpretable reasoning steps. This matters especially for small models, which may struggle with purely symbolic or equation-based solutions.

By focusing on grade-school math concepts and emphasizing linguistic diversity, GSM8K offers a valuable benchmark for assessing the informal reasoning abilities of small language models and for identifying areas to improve.

The 2B Gemma checkpoint reaches a score of 19%, which is higher than the results obtained with larger competing checkpoints.

1. Installation and download

! pip install git+https://github.com/google-deepmind/gemma.git
! pip install --user kaggle

Download the model checkpoint

To use Gemma's checkpoints, you need a Kaggle account and an API key. Here is how to get them:

Go to https://www.kaggle.com/ and create an account. Open your account settings and then the 'API' section. Click 'Create New Token' to download your key. Then run the cell below.

import kagglehub
kagglehub.login()
If everything went well, you should see:
Kaggle credentials set.
Kaggle credentials successfully validated.
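
If you are running in a non-interactive environment, kagglehub can also read credentials from environment variables instead of the login prompt. A minimal sketch, assuming your kagglehub version supports the standard KAGGLE_USERNAME and KAGGLE_KEY variables (the values below are placeholders):

import os
# Placeholder credentials; in practice keep secrets out of notebooks.
os.environ['KAGGLE_USERNAME'] = 'your_kaggle_username'
os.environ['KAGGLE_KEY'] = 'your_kaggle_api_key'
# With these set, kagglehub.model_download() can authenticate without kagglehub.login().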

2. Load the model and dataset and run the evaluation

Now select and download the checkpoint you want to try. Note that you need an A100 runtime to run the 7b model.


import os

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')

ckpt_path = os.path.join(weights_dir, VARIANT)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')
# @title Python imports
import re
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib

import datasets
import sentencepiece as spm

Load the GSM8K dataset

gsm8k = datasets.load_dataset("gsm8k", "main", cache_dir='/tmp')
gsm8k_train, gsm8k_test = gsm8k['train'], gsm8k['test']
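
Each GSM8K record is a dict with a 'question' field and an 'answer' field whose last line holds the final numeric answer after a '#### ' marker (the splits contain 7473 train and 1319 test examples). A quick look at the first test example, as a minimal sanity check:

sample = gsm8k_test[0]
print(sample['question'])  # natural-language word problem
print(sample['answer'])    # step-by-step solution ending in '#### <final number>'
print(len(gsm8k_train), len(gsm8k_test))  # 7473 1319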

# @title Testing library


def find_numbers(x: str) -> list[str]:
  """Finds all numbers in a string."""
  # Search for number, possibly negative (hyphen), with thousand separators
  # (comma), and with a decimal point (period in between digits).
  numbers = re.compile(
      r'-?[\d,]*\.?\d+',
      re.MULTILINE | re.DOTALL | re.IGNORECASE,
  ).findall(x)
  return numbers


def find_number(x: str,
                answer_delimiter: str = 'The answer is') -> str:
  """Finds the most relevant number in a string."""
  # If model uses the answer delimiter, then select the first number following
  # that format.
  if answer_delimiter in x:
    answer = x.split(answer_delimiter)[-1]
    numbers = find_numbers(answer)
    if numbers:
      return numbers[0]

  # In general, select the last number in the string.
  numbers = find_numbers(x)
  if numbers:
    return numbers[-1]
  return ''


def maybe_remove_comma(x: str) -> str:
  # Example: 5,600 -> 5600
  return x.replace(',', '')
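
A few quick checks of these helpers (a minimal sketch; the strings are made-up examples rather than GSM8K outputs):

print(find_numbers('He pays 7 * 2 = 14 dollars.'))        # ['7', '2', '14']
print(find_number('So 74 - 35 = 39. The answer is 39.'))  # '39' (first number after the delimiter)
print(find_number('The total comes to 1,234 dollars'))    # '1,234' (last number when no delimiter is present)
print(maybe_remove_comma('1,234'))                        # '1234'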

# @title GSM8K Prompts

PREAMBLE tells the model to act as an expert problem solver and work step by step. PROMPT below is the default 8-shot gsm8k prompt from the CoT paper (https://arxiv.org/pdf/2201.11903.pdf, page 35).


PREAMBLE = """As an expert problem solver solve step by step the following mathematical questions."""

# The default gsm8k prompt from the CoT paper
# https://arxiv.org/pdf/2201.11903.pdf page 35.

PROMPT = """Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.

Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.

Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.

Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.

Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.

Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.

Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8."""


# Extension of the default 8-shot prompt, page 35 in
# https://arxiv.org/pdf/2201.11903.pdf
# The extension is intended to improve performance on
# more complicated gsm8k examples.

EXTRA_3_SHOTS = """As an expert problem solver solve step by step the following mathematical questions.

Q: Tina makes $18.00 an hour.  If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage.  If she works 10 hours every day for 5 days, how much money does she make?
A: Here's how to calculate Tina's earnings:

**Regular Time:**
- Hours per shift: 8 hours
- Wage per hour: $18.00
- Regular pay per shift: 8 hours * $18.00/hour = $144.00

**Overtime:**
- Overtime hours per shift: 10 hours - 8 hours = 2 hours
- Overtime pay per hour: $18.00 + ($18.00 / 2) = $27.00
- Overtime pay per shift: 2 hours * $27.00/hour = $54.00

**Total per day:**
- Regular pay + overtime pay: $144.00/shift + $54.00/shift = $198.00/day

**Total for 5 days:**
- 5 days * $198.00/day = $990.00

**Therefore, Tina will make $990.00 in 5 days.** The answer is 990.

Q: Abigail is trying a new recipe for a cold drink. It uses 1/4 of a cup of iced tea and 1 and 1/4 of a cup of lemonade to make one drink. If she fills a pitcher with 18 total cups of this drink, how many cups of lemonade are in the pitcher?
A: ## Ambiguity in the Problem Statement:

There is one main ambiguity in the problem statement:

**Total volume vs. Number of servings:** The statement "18 total cups of this drink" could be interpreted in two ways:
  * **18 cups of the combined volume:** This would mean Abigail used a total of 18 cups of liquid, including both iced tea and lemonade.
  * **18 individual servings:** This would mean Abigail made 18 individual drinks, each containing 1/4 cup of iced tea and 1 1/4 cup of lemonade.

Let us assume the interpretation "18 cups of the combined volume".

## Solution assuming 18 cups of combined volume:

**Step 1: Find the proportion of lemonade in one drink:**

* Lemonade: 1 1/4 cups
* Iced tea: 1/4 cup
* Total: 1 1/4 + 1/4 = 1 1/2 cups
* Lemonade proportion: (1 1/4) / (1 1/2) = 5/6

**Step 2: Calculate the amount of lemonade in the pitcher:**

* Total volume: 18 cups
* Lemonade proportion: 5/6
* Volume of lemonade: 18 * (5/6) = 15 cups

Therefore, there are 15 cups of lemonade in the pitcher. The answer is 15.

Q: A deep-sea monster rises from the waters once every hundred years to feast on a ship and sate its hunger. Over three hundred years, it has consumed 847 people. Ships have been built larger over time, so each new ship has twice as many people as the last ship. How many people were on the ship the monster ate in the first hundred years?
A: Let us solve it using algebra. Let x be the number of people on the ship the monster ate in the first hundred years.

The number of people on the ship eaten in the second hundred years is 2x, and in the third hundred years is 4x.

Therefore, the total number of people eaten over three hundred years is x + 2x + 4x = 847.

Combining like terms, we get 7x = 847.

Dividing both sides by 7, we find x = 121.

Therefore, there were 121 people on the ship the monster ate in the first hundred years. The answer is 121."""

Load the language model and tokenizer

Load and prepare your LLM's checkpoint for use with Flax.
Start by loading the weights of your model.

# Load parameters
params = params_lib.load_and_format_params(ckpt_path)
Then load the tokenizer.

Load the tokenizer: use the SentencePiece tokenizer and load the vocabulary from vocab_path.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
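
A quick round trip through the tokenizer confirms the vocabulary loaded correctly. This is a minimal sketch using the standard SentencePiece API; the exact token ids depend on the checkpoint's vocabulary:

ids = vocab.EncodeAsIds('The answer is 39.')
print(ids)                   # token ids for the sentence
print(vocab.DecodeIds(ids))  # 'The answer is 39.'
print(vocab.GetPieceSize())  # vocabulary size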
Finally, build a sampler from the transformer configuration deduced from the checkpoint.

Build the sampler

First, infer the transformer configuration from the checkpoint with TransformerConfig.from_params. Then create a Transformer instance from that configuration. Finally, use sampler_lib.Sampler to create a sampler with the correct parameter shapes.
transformer_config = transformer_lib.TransformerConfig.from_params(
    params, cache_size=1024)
transformer = transformer_lib.Transformer(transformer_config)

# Create a sampler with the right param shapes for the GSM8K prompt below
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
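
Before running the full benchmark, a quick smoke test on a single question is useful. This is a minimal sketch that reuses the sampler call signature from the evaluation loop below; the toy question is made up, and the generation budget is kept small to keep it fast:

test_prompt = PREAMBLE + '\n\nQ: If a pencil costs 3 dollars, how much do 4 pencils cost?\nA:'
out = sampler(input_strings=[test_prompt], total_generation_steps=128)
print(out.text[0])
print('Extracted answer:', maybe_remove_comma(find_number(out.text[0])))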

Evaluate the model on the GSM8K test set in a loop


You should expect a score of 19.86% with the 2B model.

%%time
all_correct = 0
all_responses = {}
short_responses = {}
idx = 0
correct = 0

TEMPLATE = """
Q: {question}
A:"""

for task_id, problem in enumerate(gsm8k_test):

  if task_id in all_responses: continue

  # Print Task ID
  print(f"task_id {task_id}")

  # Formulate and print the full prompt
  full_prompt = (PREAMBLE +'\n\n' + PROMPT + '\n' +
                 TEMPLATE.format(question=problem['question']))
  short_prompt = PREAMBLE +'\n' + TEMPLATE.format(question=problem['question'])

  input_batch = [full_prompt]
  response = sampler(input_strings=input_batch, total_generation_steps=1024)
  print(response.text)

  all_responses[task_id] = response.text[0].split('\nQ:')[0]
  short_responses[task_id] = maybe_remove_comma(find_number(all_responses[task_id]))
  print(f"Short answer: {short_responses[task_id]}")
  try:
    correct += float(maybe_remove_comma(
        find_number(problem['answer']))) == float(short_responses[task_id])
  except:
    correct += maybe_remove_comma(
        find_number(problem['answer'])) == maybe_remove_comma(
            find_number(short_responses[task_id]))
  print('-'*40)
  print(f"Ground truth answer {problem['answer']}")
  print(f"Short ground truth answer {find_number(problem['answer'])}")
  print(f"Correct: {correct} out of {idx+1}")
  print("="*40)
  idx += 1
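
Once the loop finishes, the overall accuracy is simply the fraction of questions answered correctly. A minimal summary along these lines (the reporting below is an assumption, not part of the original cell):

# Assumes the loop above ran over the whole of gsm8k_test.
accuracy = correct / idx
print(f'GSM8K test accuracy: {accuracy:.2%} ({int(correct)} / {idx} correct)')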
