# -*- coding: utf-8 -*-
"""
Created on Fri Apr 14 15:02:33 2023
@author: wang.junjun3
"""
# 加载模型
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("ClueAI/PromptCLUE-base",cache_dir=r'C:\Users\wang.junjun3\model')
model = T5ForConditionalGeneration.from_pretrained("ClueAI/PromptCLUE-base",cache_dir=r'C:\Users\wang.junjun3\model')
import torch
device = torch.device('cpu')
# device = torch.device('cuda')
model.to(device)
def preprocess(text):
return text.replace("\n", "_")
def postprocess(text):
return text.replace("_", "\n")
def answer(text, sample=False, top_p=0.8):
'''sample:是否抽样。生成任务,可以设置为True;
top_p:0-1之间,生成的内容越多样'''
text = preprocess(text)
encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
if not sample:
out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_length=128, num_beams=4, length_penalty=0.6)
else:
out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_length=64, do_sample=True, top_p=top_p)
out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
return postprocess(out_text[0])
text = '为下面的文章生成摘要:你们有没有感觉这个车子空调功率有问题 手机远程启动空调,或者不猜刹车只按启动键,这2种情况下得空调效果很弱,制冷效果差。踩住刹车再按启动键(可以开始驾驶),空调的风速和出风温度立马就加强了,感觉是之前只用了-小半功率在运行。问4S店,说这个是正常现象,有没有其他车友也遇到这种情况的'
res = answer(text)
print(res)
参考 中文语言理解测评基准(CLUE)
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 14 14:39:26 2023
@author: wang.junjun3
"""
import clueai
# initialize the Clueai Client with an API Key
cl = clueai.Client("", check_api_key=False)
prompt= '''
摘要:
你们有没有感觉这个车子空调功率有问题 手机远程启动空调,或者不猜刹车只按启动键,这2种情况下得空调效果很弱,制冷效果差。踩住刹车再按启动键(可以开始驾驶),空调的风速和出风温度立马就加强了,感觉是之前只用了-小半功率在运行。问4S店,说这个是正常现象,有没有其他车友也遇到这种情况的
答案:
'''
# generate a prediction for a prompt
generate_config = {
"do_sample": True,
"top_p": 0.8,
"max_length": 128,
"min_length": 10,
"length_penalty": 1.0,
"num_beams": 1
}
# 如果需要自由调整参数自由采样生成,添加额外参数信息设置方式:generate_config=generate_config
prediction = cl.generate(
model_name='clueai-base',
prompt=prompt)
# 需要返回得分的话,指定return_likelihoods="GENERATION"
# print the predicted text
print('prediction: {}'.format(prediction.generations[0].text))