书接上集, 接下来我将教大家如何手动加载llama3并且使用它
下载模型
模型可以从huggingface上下下载:
meta-llama/Meta-Llama-3-8B-Instruct at main
-
consolidated.00.pth: llama3 模型参数文件
-
params.json: 超参文件
-
tokenizer.model: tokenizer 模型参数文件
这三个文件下载到一个目录中
加载llama3
二话不说, 先上代码
MODEL_DIR = "D:\AI\meta-llama\Meta-Llama-3-8B-Instruct\original"
MAX_SEQ_LEN = 512
MAX_BATCH_SIZE = 1
SEED = 42
TEMPERATURE = 0.6
TOP_P = 0.9
DEVICE = "cuda" if torch.cuda.is_available else "cpu"
- MODEL_DIR: 模型的路径, 上面下载的文件存放的目录
- MAX_SEQ_LEN: 一次推理能处理的最长序列长度
- MAX_BATCH_SIZE: 批处理大小
- SEED: 随机数种子
- TEMPERATURE: 温度值
- TOP_P: top p算法阈值
- DEVICE: 主体设配
def loadModel():
model_path = os.path.join(MODEL_DIR, "consolidated.00.pth")
params_path = os.path.join(MODEL_DIR, "params.json")
tokenizer_path = os.path.join(MODEL_DIR, "tokenizer.model")
torch.manual_seed(SEED)
start_time = time.time()
tokenizer = Tokenizer(model_path=tokenizer_path)
with open(params_path, "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=MAX_SEQ_LEN,
max_batch_size=MAX_BATCH_SIZE,
**params,
)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
checkpoint = torch.load(model_path, map_location="cpu")
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return model, tokenizer
这个没什么讲的, 就是加载模型, 值得讲下的是torch.load
torch.load 首先在CPU中反序列化存储的模型参数, 然后再搬运到map_location指定的设备中,如果是CPU,说明反序列化后就不用搬运。
模型load起来后, 接下来就是和他聊天了
chat
首先建一个轮询, 等待用户输入prompt
def chatLoop(model, tokenizer):
while True:
prompt = input("\33[32mprompt: ") ##要想生活过的去, prompt必须来点绿
if prompt == "exit":
break
chat(model, tokenizer, prompt)
input 是接受用户输入函数,“\33[32m” 是设置显示字体的颜色
def chat(model, tokenizer, prompt):
formatter = ChatFormat(tokenizer)
dialog = [{"role": "user", "content": prompt}]
prompt_token = formatter.encode_dialog_prompt(dialog)
prompt_len = len(prompt_token)
if prompt_len >= MAX_SEQ_LEN:
print(f"\33[31m你丫问题太长了, 我很难办!\033[0m")
return
total_len = MAX_SEQ_LEN
pad_id = tokenizer.pad_id
size = (1, total_len) # batch_size = 1
token_list = torch.full(size, pad_id, dtype=torch.long, device=DEVICE)
token_list[0, :prompt_len] = torch.tensor(prompt_token, dtype=torch.long, device=DEVICE)
prev_pos = 0
stop_tokens = list(tokenizer.stop_tokens)
bstr = b""
print("\033[34mLlama3: ", end='', flush=True)
for cur_pos in range(prompt_len, total_len):
logits = model.forward(token_list[:, prev_pos:cur_pos], prev_pos)
next_token = select_next_token(logits)
token_list[:, cur_pos] = next_token
if next_token[0] in stop_tokens:
break
ansbytes = tokenizer.decode_single_token_bytes(next_token)
bstr += ansbytes
s = decodeUtf8(bstr)
if s:
print(s, end='', flush=True)
bstr = b""
prev_pos = cur_pos
print("\033[0m")
ChatFormat 把prompt格式化成特定格式, 他的encode_dialog_prompt函数还把prompt给token化了, 输出token 列表
接着判断这个列表的长度, 太长了会让模型很难办
然后创建一个GPU tensor, 3个维度, 第一维是batch, 大小为1, 第二维是序列长度, 大小为MAX_SEQ_LEN, 第三维是token在词表中的位置。这个tensor初始化填充tokenizer.pad_id
然后把之前输出的prompt_token复制到token_list
然后一个for循环, 一个token一个token的生成答案
logits = model.forward(token_list[:, prev_pos:cur_pos], prev_pos)
调用模型, 生成logits,token_list在GPU中, model 加载的时候在CPU中, 计算的时候应该先转移到GPU中, 然后再计算,这是灵哥的合理化推测(resonable doubt)
next_token = select_next_token(logits)
根据logits, 使用top p算法,得到下一个token: next_token.
def select_next_token(logits):
if TEMPERATURE > 0:
probs = torch.softmax(logits[:, -1] / TEMPERATURE, dim=-1)
next_token = sample_top_p(probs, TOP_P)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
return next_token
top p算法灵哥之前讲过, 可以看灵哥以前的文章
token_list[:, cur_pos] = next_token
再把next_token添加到token_list的末尾,以完成一轮自回归
if next_token[0] in stop_tokens:
break
判断next_token是否是stop_tokens中的一个, 如果是,则模型回答问题完毕, 退出循环
ansbytes = tokenizer.decode_single_token_bytes(next_token)
bstr += ansbytes
s = decodeUtf8(bstr)
if s:
print(s, end='', flush=True)
bstr = b""
将next_token解码成utf-8编码的字节, 有的token对应的不是一个完成的utf-8字符, 是其中的一部分
s = decodeUtf8(bst)
def decodeUtf8(bstr):
try:
return bstr.decode("utf-8")
except UnicodeDecodeError:
return None
如果可以成功解码, 说明是一个完整的utf-8字符, 然后显示这个字符
如果不能成功编码, 说明不是一个完整的utf-8字符, 先把字节保存到bstr中, 然后等待下一个token的到来
运行效果
代码地址
代码存放在灵哥的github里, 有兴趣的同学可以下载下来玩玩
请大家转发点赞和关注, 以支持灵哥的创作