NLP(七十一)大模型微调RACE数据集的进一步实验

欢迎关注我的公众号NLP奇幻之旅,原创技术文章第一时间推送。

欢迎关注我的知识星球“自然语言处理奇幻之旅”,笔者正在努力构建自己的技术社区。

在文章NLP(七十)使用LLAMA 2模型微调Multiple Choice MRC 中,笔者介绍了RACE数据集,以及如何使用LLAMA 2模型对该数据集进行微调,较以往的BERT系列模型取得了长足进步。

本文将继续上述实验。

实验

我们的实验使用Firefly大模型训练框架,数据集为RACE middle数据集,训练参数在NLP(七十)使用LLAMA 2模型微调Multiple Choice MRC 中已经给出。

微调方式采用大模型 + qlora, 对数据集进行SFT(Supervised Fine-Tuning)。针对不同的模型、模型尺寸、学习率、最大长度等参数进行调试(其余参数不变),实验结果如下:

模型学习率训练轮数最大长度准确率
LLAMA-2-7B1e-433840.8691
LLAMA-2-7B1e-433200.8593
LLAMA-2-7B1e-453840.8545
LLAMA-2-7B1e-453200.8538
LLAMA-2-13B1e-433840.8844
Baichuan-7B2e-433840.8357
Baichuan-13B-Chat1e-433840.8726
Baichuan2-13B-Base2e-433840.8948
XVERSE-13B1e-433840.8718

在上述实验中,LLAMA-2-13B, Baichuan-13B-Chat, Baichuan2-13B-Base, XVERSE-13B模型取得了不错的效果,尤其Baichuan2-13B-Base效果最好,accuracy达到了89.48%。

我们对LLAMA-2-13B + Baichuan2-13B-Base + XVERSE-13B进行模型集成(取预测结果中的次数最大值,如果都相同,则以Baichuan2-13B-Base的结果为准),accuracy为90.25%!

使用全量RACE数据中的训练集进行训练,使用LLAMA-2-13B,学习率1e-4,轮数3,最大长度384,测试结果如下:

测试集准确率
middle0.9283
high0.8413
all0.8666

我们再回过头来看看,在以往BERT时代,RACE数据集排行榜

可视化

我们使用Gradio模块,对文章、问题、选项进行可视化问答,页面效果:

实现的Python代码可以参考如下:

# -*- coding: utf-8 -*-
import gradio as gr
from transformers import AutoTokenizer
import torch


import sys
sys.path.append("../../")
from component.utils import ModelUtils

# 使用合并后的模型进行推理
model_name_or_path = '/home/jclian91/experiment/Firefly/script/checkpoint/firefly-llama2-7b-qlora-sft-race-merge'
# 生成超参配置
max_new_tokens = 1
top_p = 0.9
temperature = 0.01
repetition_penalty = 1.0
device = 'cuda:0'
# 加载模型
model = ModelUtils.load_model(
    model_name_or_path,
    load_in_4bit=False,
    adapter_name_or_path=None
).eval()
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    # llama不支持fast
    use_fast=False if model.config.model_type == 'llama' else True
)
print(f"load model: {model_name_or_path}")


# Gradio app
def predict(passage, question, options):
	# make prompt
    prefix = 'Read the following passage and questions, then choose the right answer from options, ' \
             'the answer should be one of A, B, C, D.\n\n'
    passage = f'<passage>:\n{passage}\n\n'
    question = f'<question>:\n{question}\n\n'
    option1, option2, option3, option3 = options.split("\n")
    option = f'<options>:\n{option1.strip()}\n{option2.strip()}\n{option3.strip()}\n{option3.strip()}\n\n'
    suffix = f"<answer>:\n"
    prompt = ''.join([prefix, passage, question, option, suffix])
    # get input ids
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
    bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
    eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
    input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
    # model predict
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
            top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id
        )
    outputs = outputs.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs)
    response = response.strip().replace(tokenizer.eos_token, "").strip()
    return f"The answer is {response}."


with gr.Blocks() as demo:
    # 设置输入组件
    gr_passage = gr.Textbox(lines=3, placeholder="Passage", label="Passage")
    gr_question = gr.Textbox(lines=1, placeholder="question", label="question")
    gr_options = gr.Textbox(lines=4, placeholder="options", label="options")
    # 设置输出组件
    answer = gr.Textbox(label="Answer")
    # 设置按钮
    greet_btn = gr.Button("Show me the answer")
    # 设置按钮点击事件
    greet_btn.click(fn=predict,
                    inputs=[gr_passage, gr_question, gr_options],
                    outputs=answer)

demo.launch(share=True)

总结

本文并无太多新意,只是在之前文章的基础上进行大量测试,考察不同模型,模型尺寸在SFT上的效果差异,同时也验证了,大模型时代中LLM的强大之处,在效果上几乎横扫以往所有的NLP模型,且将以往的NLP任务做到了统一,这才是LLM的可怕之处!

本文的代码及实验结果已开放至Github,网址为 https://github.com/percent4/llama-2-multiple-choice-mrc .

本人博客网站为 https://percent4.github.io/ ,欢迎大家访问~

推荐阅读
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值