Vicuna 集成语句：用 Vicuna 融合三个模型（BLIP、BLIP-2、Flamingo）生成的解说语句

import os
import argparse
import json
import random

from fastchat.serve.inference import generate_stream
from fastchat.conversation import get_conv_template, SeparatorStyle
from fastchat.model.model_adapter import load_model, get_conversation_template

def _load_predictions(path):
    """Return the ``predictions`` list stored in the JSON result file at *path*."""
    with open(path, 'r') as fp:
        return json.load(fp)['predictions']


def _iter_result_files(root_dir):
    """Yield the path of every file found (recursively) under *root_dir*."""
    for root, _dirs, files in os.walk(root_dir):
        for name in files:
            yield os.path.join(root, name)


def _refine_caption(model, tokenizer, model_path, device, input_cap,
                    temperature, repetition_penalty, max_new_tokens):
    """Ask the chat model to refine *input_cap* and return its raw reply text.

    Returns an empty string if the generator yields nothing.
    """
    # A fresh conversation per caption, so earlier turns never leak into the prompt.
    conv = get_conversation_template(model_path)
    conv.append_message(conv.roles[0], f'Please help me refine this sentence: {input_cap}')
    conv.append_message(conv.roles[1], None)

    gen_params = {
        "model": model_path,
        "prompt": conv.get_prompt(),
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    outputs = list(generate_stream(model, tokenizer, gen_params, device))
    # The last streamed item is taken as the complete reply (presumably the
    # stream is cumulative -- confirm against the FastChat version in use).
    return outputs[-1]['text'] if outputs else ''


def _extract_caption(text):
    """Pick the refined caption out of the raw model reply *text*.

    Double quotes are removed, and the first line that mentions a ``[PLAYER]``
    or ``[TEAM]`` placeholder is kept. If that line contains a colon, the lead-in
    up to the first colon is dropped. The lead-in is presumably a question or
    label the model echoes back.

    Returns:
        The extracted caption, or ``None`` if no line contains a placeholder.
    """
    for sentence in text.replace('"', '').split('\n'):
        if '[PLAYER]' in sentence or '[TEAM]' in sentence:
            caption = sentence.strip()
            break
    else:
        return None
    if ':' in caption:
        # Split only on the first colon, so later colons that belong to the
        # caption itself (e.g. a time such as "45:00") are kept.
        caption = caption.split(':', 1)[1].strip()
    return caption


def main(args):
    """Merge three result sets' captions per prediction and rewrite them with Vicuna.

    Every file under ``<result_path>/<sub_path>`` is processed. The file at the
    same relative path under ``<second_path>/<sub_path>`` and
    ``<three_path>/<sub_path>`` is loaded alongside it. For each prediction:

    1. The three ``comment`` strings are joined with spaces.
    2. The model is asked to refine the joined text.
    3. The extracted reply is stored as the new ``comment``.

    ``gameTime`` and ``score`` are taken from the first result set. Output is
    written as ``{"predictions": [...]}`` to the same relative path under
    ``<new_result_path>/<sub_path>``.

    Args:
        args: Namespace with ``result_path``, ``second_path``, ``three_path``,
            ``new_result_path`` and ``sub_path``. It may also carry
            ``model_path`` (default ``'/vicuna-7b'``) and ``device`` (default
            ``'cuda'``).

    Raises:
        ValueError: If a second or third result file holds fewer predictions
            than the matching first file.
    """
    sub_path = args.sub_path
    result_path = os.path.join(args.result_path, sub_path)
    new_result_path = os.path.join(args.new_result_path, sub_path)
    second_root = os.path.join(args.second_path, sub_path)
    third_root = os.path.join(args.three_path, sub_path)

    model_path = getattr(args, 'model_path', None) or '/vicuna-7b'
    device = getattr(args, 'device', None) or 'cuda'

    num_gpus = 1
    max_gpu_memory = None
    load_8bit = False
    cpu_offloading = False
    debug = False

    model, tokenizer = load_model(
        model_path, device, num_gpus, max_gpu_memory, load_8bit, cpu_offloading, debug
    )

    temperature = 0.7
    repetition_penalty = 1.0
    max_new_tokens = 512

    # List the input files once. This also keeps newly written outputs out of
    # the loop if the output directory happens to lie inside the input tree.
    fpaths = list(_iter_result_files(result_path))
    total_len = sum(len(_load_predictions(p)) for p in fpaths)
    print(f'{sub_path} - total to process: {total_len}')

    generated_len = 0
    for fpath in fpaths:
        # Map the file to its counterparts via the path relative to the input
        # root, rather than a substring replace on the full path.
        rel_path = os.path.relpath(fpath, result_path)
        res_list = _load_predictions(fpath)
        res_list_blipv2 = _load_predictions(os.path.join(second_root, rel_path))
        res_list_flamingo = _load_predictions(os.path.join(third_root, rel_path))
        if len(res_list_blipv2) < len(res_list) or len(res_list_flamingo) < len(res_list):
            raise ValueError(
                f'{rel_path}: second/third result files have fewer predictions '
                f'than {fpath} ({len(res_list_blipv2)}, {len(res_list_flamingo)} < {len(res_list)})'
            )

        new_res_list = []
        for i, res in enumerate(res_list):
            game_time = res["gameTime"]
            score = res["score"]
            cap = res["comment"]
            cap_blipv2 = res_list_blipv2[i]["comment"]
            cap_flamingo = res_list_flamingo[i]["comment"]

            input_cap = '{} {} {}'.format(cap, cap_blipv2, cap_flamingo)
            print('input_cap: {}'.format(input_cap))

            reply = _refine_caption(
                model, tokenizer, model_path, device, input_cap,
                temperature, repetition_penalty, max_new_tokens,
            )
            new_cap = _extract_caption(reply)

            new_res_list.append({
                "gameTime": game_time,
                "label": "comments",
                "comment": new_cap,
                "score": score,
            })
            generated_len += 1
            print(f'{sub_path} - {generated_len} / {total_len}: {new_cap}\n')

        # Write only after the whole file has been processed, so a failure
        # never leaves an empty or partial output file behind.
        new_res_fpath = os.path.join(new_result_path, rel_path)
        new_res_dir = os.path.dirname(new_res_fpath)
        if new_res_dir:
            os.makedirs(new_res_dir, exist_ok=True)
        with open(new_res_fpath, 'w') as fw:
            json.dump({"predictions": new_res_list}, fw)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument("--result-path", type=str, default='', help="caption original path")
    # Fixed: the default was an unterminated string literal ('"), a SyntaxError.
    parser.add_argument("--second-path", type=str, default='', help="caption path of the second model")
    # main() reads args.three_path, so the option must be defined here.
    parser.add_argument("--three-path", type=str, default='', help="caption path of the third model")
    parser.add_argument("--new-result-path", type=str, default='', help="new path for caption restated")
    parser.add_argument("--sub-path", type=str, default='england_epl', help="league name")

    args = parser.parse_args()
    main(args)
        
        

GitHub - lm-sys/FastChat: An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and FastChat-T5.

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值