minigpt4/minigpt-v2批量推理/测试的代码

最近很多人需要关于 minigpt4/minigpt-v2 批量推理/测试的代码,而且不需要 gradio。我来贡献一下我写的,其实特别简单,就是把 gradio demo 那边的代码稍微改一下就可以:

单张测试

# The part before this stays as in the original demo script.

# ========================================
#             Model Initialization
# ========================================

# Map each model_type in the config to its conversation template.
conv_dict = {
    'pretrain_vicuna0': CONV_VISION_Vicuna0,
    'pretrain_llama2': CONV_VISION_LLama2,
}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

# Single device string reused for model, stop tokens, and the Chat wrapper.
device = f'cuda:{args.gpu_id}'

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model = registry.get_model_class(model_config.arch).from_config(model_config).to(device)

# Pick the conversation template matching the loaded checkpoint.
CONV_VISION = conv_dict[model_config.model_type]

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

# Token-id sequences that terminate generation ('###' and '##' variants).
stop_words_ids = [torch.tensor(ids, device=device) for ids in ([835], [2277, 29937])]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

chat = Chat(model, vis_processor, device=device, stopping_criteria=stopping_criteria)
print('Initialization Finished')

# ================ Test a single image =========================
# Your instruction text; it is appended after the '[grounding]' task tag.
# (The original concatenated an undefined name `USER_PROMPT_CUB_NOVEL`,
# a leftover from the author's private setup, which raised NameError —
# fold the whole prompt into this placeholder instead.)
USER_ANS_FORMAT = ''

# Path to the image file to test.
gr_img = ''
chat_state = CONV_VISION.copy()
img_list = []
# Example prompt: '[grounding] describe this image in detail'
user_message = '[grounding]' + USER_ANS_FORMAT

# Register the image with the conversation state.
chat.upload_img(gr_img, chat_state, img_list)

# Ask the question, then encode the uploaded image into embeddings.
chat.ask(user_message, chat_state)
chat.encode_img(img_list)

# Generate the answer; answer() returns (text, tokens) — keep the text.
llm_message = chat.answer(conv=chat_state,
                          img_list=img_list,
                          temperature=1.5,
                          max_new_tokens=1000,
                          max_length=2000)[0]
print(llm_message)

批量测试

# ========================================
#             Model Initialization
# ========================================

# model_type -> conversation template used by that checkpoint family.
conv_dict = {
    'pretrain_vicuna0': CONV_VISION_Vicuna0,
    'pretrain_llama2': CONV_VISION_LLama2,
}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

# One device string shared by the model, the stop tokens and Chat.
device = f'cuda:{args.gpu_id}'

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)

CONV_VISION = conv_dict[model_config.model_type]

# Use the training-time visual processor from the cc_sbu_align dataset config.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

# Generation stops when one of these token-id sequences is produced.
stop_words_ids = [torch.tensor(ids, device=device) for ids in ([835], [2277, 29937])]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

chat = Chat(model, vis_processor, device=device, stopping_criteria=stopping_criteria)
print('Initialization Finished')


# ========================== Evaluation =======================

# Root directory: one sub-directory per class (named like '001.ClassName'),
# each containing that class's images.
root_dir_base = ''

# Collect every image path. The ground-truth label is recovered later from
# each image's parent directory name, so no parallel label list is needed.
# (The original kept a per-directory `labels` list that was misaligned with
# the per-image `image_paths` list and never actually used.)
# You could equally build this with a Dataset — it is the same thing.
image_paths = []
for label_dirname in os.listdir(root_dir_base):
    # BUG FIX: the original joined against an undefined `root_dir_novel`.
    label_dir = os.path.join(root_dir_base, label_dirname)
    for image_file in os.listdir(label_dir):
        image_paths.append(os.path.join(label_dir, image_file))

# Your instruction text here; appended after the '[grounding]' task tag.
USER_ANS_FORMAT = ''

chat_state = CONV_VISION.copy()
right_count = 0
for gr_img in image_paths:
    # BUG FIX: the original read `image_path`, the stale variable left over
    # from the directory scan above, so every iteration was scored against
    # the label of the LAST scanned image. Derive the label from the
    # current image's parent directory instead.
    label_num = os.path.basename(os.path.dirname(gr_img))
    # Directory names look like '001.ClassName'; keep the part after the dot.
    label = label_num.split('.')[-1]

    img_list = []
    # Example prompt: '[grounding] describe this image in detail'
    user_message = '[grounding]' + USER_ANS_FORMAT

    # Upload the image, ask the question, then encode the image.
    chat.upload_img(gr_img, chat_state, img_list)
    chat.ask(user_message, chat_state)
    chat.encode_img(img_list)

    # answer() returns (text, tokens); keep the generated text.
    llm_message = chat.answer(conv=chat_state,
                              img_list=img_list,
                              temperature=1.5,
                              max_new_tokens=1000,
                              max_length=2000)[0]
    print('Ans: ' + llm_message + ', Ground_truth: ' + label_num)
    # Count a hit when the class name appears anywhere in the answer.
    if label in llm_message:
        right_count += 1

    # Reset the conversation so history does not leak between images.
    chat_state.messages = []

# Guard against an empty image set to avoid ZeroDivisionError.
total = len(image_paths)
acc = right_count / total if total else 0.0
print('Total: %d, Right_count: %d, Acc: %.3f' % (total, right_count, acc))

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值