import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import bitsandbytes as bnb
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
"""
opt-6.7b模型,它以float16的精度存储,大小大约为13GB!如果我们使用bitsandbytes库以8位加载它们,我们需要大约7GB的显存
"""
# load_in_8bit=True asks transformers to quantize the weights to 8-bit via the bitsandbytes library
model = AutoModelForCausalLM.from_pretrained("facebook/opt-6.7b", load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b")
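# Optional sanity check, a minimal sketch: get_memory_footprint() on a Hugging Face
# PreTrainedModel reports the weights' memory usage in bytes, which should land
# around the ~7 GB expected for the 8-bit model.
print(f"memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GB")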
# Preprocess the modules that are not kept in int8 (e.g. layer norms) so they train in higher precision
from peft import prepare_model_for_int8_training  # renamed to prepare_model_for_kbit_training in newer PEFT releases
model = prepare_model_for_int8_training(model)
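# Rough illustration only (not the library's actual implementation): the idea of this
# preprocessing step is to freeze the int8 base weights and keep small, numerically
# sensitive modules such as layer norms in float32, roughly along these lines:
#   for param in model.parameters():
#       param.requires_grad = False       # freeze the quantized base model
#   for name, module in model.named_modules():
#       if "layer_norm" in name:          # hypothetical name filter, model-dependent
#           module.to(torch.float32)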
# Configure the LoRA parameters: rank-16 adapters on the attention query/value projections
from peft import LoraConfig, get_peft_model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
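# Minimal LoRA math sketch (illustrative only, not PEFT internals): each frozen weight W
# is adapted as W + (lora_alpha / r) * (B @ A), where A is (r x d_in) and B is (d_out x r),
# so only r * (d_in + d_out) parameters are trained per targeted projection.
d_in, d_out, r, alpha = 64, 64, 16, 32        # toy dimensions for illustration
W = torch.randn(d_out, d_in)                  # stands in for a frozen q_proj/v_proj weight
A = torch.randn(r, d_in) * 0.01               # trainable low-rank factor A
B = torch.zeros(d_out, r)                     # trainable low-rank factor B (zero-init, so the delta starts at 0)
W_adapted = W + (alpha / r) * (B @ A)         # effective weight seen by the forward pass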
def print_trainable_parameters(model):
    """Prints the number of trainable parameters in the model."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
# Load the data (an English quotes dataset serves as the training set)
from datasets import load_dataset
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
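# Quick sanity check (a sketch): map() added input_ids / attention_mask columns
# that the data collator below will batch and pad dynamically.
example = data["train"][0]
print(example["quote"], "->", len(example["input_ids"]), "tokens")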
# Training
trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
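# Persist only the small LoRA adapter (a sketch; "outputs/lora-adapter" is an arbitrary path).
# PEFT's save_pretrained writes just the adapter weights and config, not the 6.7B base model.
model.save_pretrained("outputs/lora-adapter")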
# Inference
batch = tokenizer("Two things are infinite: ", return_tensors="pt")
with torch.cuda.amp.autocast():
    output_tokens = model.generate(**batch, max_new_tokens=50)
print("\n\n", tokenizer.decode(output_tokens[0], skip_special_tokens=True))