glm推理

import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
 
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)
# 载入模型
config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("chatglm-6b", config=config, trust_remote_code=True)
CHECKPOINT_PATH='output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000'
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
 
# Comment out the following line if you don't use quantization
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()
 
response, history = model.chat(tokenizer, "类型#工装裤*颜色#深蓝色*图案#条纹*裤长#八分裤", history=[])
print(response)


'''
import json

def generate_summaries(test_set):
    for item in test_set:
        # 将content字段作为输入,调用模型生成摘要
        response, history = model.chat(tokenizer, item["content"], history=[])
        # 直接将生成的摘要存储在原item的summary字段
        item["summary"] = response
    return test_set

# 调用函数生成摘要并更新测试集
updated_test_set = generate_summaries(test_set)

# 打印结果
for item in updated_test_set:
    print(item)

# 选择一个文件名和路径来保存 JSON 数据
file_path = 'updated_test_set.json'

# 使用 json 模块将数据写入文件
with open(file_path, 'w', encoding='utf-8') as json_file:
    json.dump(updated_test_set, json_file, ensure_ascii=False, indent=4)

print(f"Updated test set has been saved to {file_path}")
'''

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值