import os
import argparse
import json
import random
from fastchat.serve.inference import generate_stream
from fastchat.conversation import get_conv_template, SeparatorStyle
from fastchat.model.model_adapter import load_model, get_conversation_template
def main(args):
    """Merge soccer-commentary captions from three caption models and
    rephrase each merged caption with a local Vicuna-7B model.

    For every prediction file under ``args.result_path/args.sub_path``
    (first model), the matching files under ``args.second_path`` and
    ``args.three_path`` are loaded, the three captions for each event are
    concatenated, sent to Vicuna for refinement, and the refined captions
    are written to the mirrored path under ``args.new_result_path``.

    Args:
        args: argparse namespace with ``result_path``, ``new_result_path``,
            ``second_path``, ``three_path`` and ``sub_path`` attributes.
    """
    result_path = args.result_path
    new_result_path = args.new_result_path
    three_path = args.three_path
    sub_path = args.sub_path
    second_path = args.second_path

    # Hard-coded single-GPU Vicuna-7B setup.
    model_path = '/vicuna-7b'
    device = 'cuda'
    num_gpus = 1
    max_gpu_memory = None
    load_8bit = False
    cpu_offloading = False
    debug = False
    model, tokenizer = load_model(
        model_path, device, num_gpus, max_gpu_memory, load_8bit, cpu_offloading, debug
    )
    temperature = 0.7
    repetition_penalty = 1.0
    max_new_tokens = 512

    result_path = os.path.join(result_path, sub_path)
    new_result_path = os.path.join(new_result_path, sub_path)

    # First pass: count all predictions so progress can be reported.
    total_len = 0
    for root, dirs, files in os.walk(result_path):
        for f in files:
            fpath = os.path.join(root, f)
            if os.path.exists(fpath):
                # with-statement closes the handle (the original leaked it)
                with open(fpath, 'r') as fr:
                    res_dict = json.load(fr)
                total_len += len(res_dict['predictions'])
    print(f'{sub_path} - total to process: {total_len}')

    generated_len = 0
    for root, dirs, files in os.walk(result_path):
        for f in files:
            fpath = os.path.join(root, f)
            if not os.path.exists(fpath):
                continue
            # Read original results from the three caption models; the
            # second/third models mirror the first model's directory layout.
            with open(fpath, 'r') as fr:
                res_dict_blip = json.load(fr)
            with open(fpath.replace(result_path, os.path.join(second_path, sub_path)), 'r') as fr:
                res_dict_blipv2 = json.load(fr)
            with open(fpath.replace(result_path, os.path.join(three_path, sub_path)), 'r') as fr:
                res_dict_flamingo = json.load(fr)
            res_list = res_dict_blip['predictions']
            res_list_blipv2 = res_dict_blipv2['predictions']
            res_list_flamingo = res_dict_flamingo['predictions']

            # Prepare mirrored output location.
            new_res_dir = root.replace(result_path, new_result_path)
            os.makedirs(new_res_dir, exist_ok=True)
            new_res_fpath = fpath.replace(result_path, new_result_path)

            new_res_list = []
            for i, res in enumerate(res_list):
                cap = res["comment"]
                # BUG FIX: original indexed res_dict_flamingo[i] (a dict keyed
                # by "predictions" — int index raises KeyError); the BLIP-v2
                # prediction list is what this record should come from.
                cap_blipv2 = res_list_blipv2[i]["comment"]
                cap_flamingo = res_list_flamingo[i]["comment"]
                input_cap = '{} {} {}'.format(cap, cap_blipv2, cap_flamingo)
                print('input_cap: {}'.format(input_cap))

                # ======= restate captions ======= #
                inp = f'Please help me refine this sentence: {input_cap}'
                conv = get_conversation_template(model_path)
                conv.append_message(conv.roles[0], inp)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()
                gen_params = {
                    "model": model_path,
                    "prompt": prompt,
                    "temperature": temperature,
                    "repetition_penalty": repetition_penalty,
                    "max_new_tokens": max_new_tokens,
                    "stop": conv.stop_str,
                    "stop_token_ids": conv.stop_token_ids,
                    "echo": False,
                }
                output_stream = generate_stream(model, tokenizer, gen_params, device)

                # Keep the first generated line that still carries a
                # [PLAYER]/[TEAM] placeholder; drop surrounding quotes.
                new_cap_list = list(output_stream)[-1]['text'].replace('\"', '').split('\n')
                new_cap = None
                for sentence in new_cap_list:
                    if '[PLAYER]' in sentence or '[TEAM]' in sentence:
                        new_cap = sentence.strip()
                        break
                if new_cap and ':' in new_cap:
                    # Strip the leading question/prefix before the colon.
                    new_cap = str(new_cap).strip().split(':')[1]

                new_res_list.append({
                    "gameTime": res["gameTime"],
                    "label": "comments",
                    "comment": new_cap,
                    "score": res["score"],
                })
                generated_len += 1
                print(f'{sub_path} - {generated_len} / {total_len}: {new_cap}\n')

            # Empty file still gets an (empty) predictions list.
            with open(new_res_fpath, 'w') as fw:
                json.dump({"predictions": new_res_list}, fw)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--result-path", type=str, default='', help="caption original path")
    # BUG FIX: original line was `default='")` — an unterminated string
    # literal that made the whole file a SyntaxError.
    parser.add_argument("--second-path", type=str, default='', help="second model caption path")
    # BUG FIX: --three-path was missing entirely although main() reads
    # args.three_path, which raised AttributeError at runtime.
    parser.add_argument("--three-path", type=str, default='', help="third model caption path")
    parser.add_argument("--new-result-path", type=str, default='', help="new path for caption restated")
    parser.add_argument("--sub-path", type=str, default='england_epl', help="league name")
    args = parser.parse_args()
    main(args)
# Vicuna caption-integration script (translated note: "vicuna integration
# statements"; trailing blog-scrape metadata removed so the file parses).