最近很多人需要关于minigpt4/minigpt-v2批量推理/测试的代码,而且不需要gradio。我来贡献一下我写的,其实特别简单,就是把gradio那边改了就可以:
单张测试
# Nothing above this point (imports, arg parsing helpers) needs changing.
# ========================================
# Model Initialization
# ========================================
conv_dict = {
    'pretrain_vicuna0': CONV_VISION_Vicuna0,
    'pretrain_llama2': CONV_VISION_LLama2,
}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
device = f'cuda:{args.gpu_id}'
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
CONV_VISION = conv_dict[model_config.model_type]

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

# Token-id sequences that terminate generation (presumably "###"/"##"
# separators of the conversation template — TODO confirm against tokenizer).
stop_words_ids = [torch.tensor(ids).to(device=device) for ids in ([835], [2277, 29937])]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

chat = Chat(model, vis_processor, device=device, stopping_criteria=stopping_criteria)
print('Initialization Finished')
# ================ Single-image test =========================
# Answer-format instruction appended to the prompt; fill in your own.
USER_ANS_FORMAT = ''

# Path of the image to test; fill in your own.
gr_img = ''

chat_state = CONV_VISION.copy()
img_list = []

# Example: user_message = '[grounding] describe this image in detail'
user_message = '[grounding]' + USER_ANS_FORMAT + USER_PROMPT_CUB_NOVEL

# Upload the image, ask the question, then encode the image features.
chat.upload_img(gr_img, chat_state, img_list)
chat.ask(user_message, chat_state)
chat.encode_img(img_list)

# Generate the reply; index 0 of the returned tuple is presumably the
# generated text — the remainder is discarded here.
llm_message = chat.answer(
    conv=chat_state,
    img_list=img_list,
    temperature=1.5,
    max_new_tokens=1000,
    max_length=2000,
)[0]
print(llm_message)
批量测试
# ========================================
# Model Initialization
# ========================================
conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
             'pretrain_llama2': CONV_VISION_LLama2}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id

gpu_device = 'cuda:{}'.format(args.gpu_id)
model = registry.get_model_class(model_config.arch).from_config(model_config).to(gpu_device)
CONV_VISION = conv_dict[model_config.model_type]

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

# Token-id sequences that stop generation (presumably the "###"/"##"
# markers of the conversation template — TODO confirm against tokenizer).
stop_words_ids = [torch.tensor(ids).to(device=gpu_device)
                  for ids in [[835], [2277, 29937]]]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

chat = Chat(model, vis_processor, device=gpu_device, stopping_criteria=stopping_criteria)
print('Initialization Finished')
# ========================== Evaluation =======================
# Dataset root: one sub-directory per class, named like "001.ClassName".
root_dir_base = ''
labels = []
image_paths = []

# Collect every image path together with its class label.
# (You could also wrap this in a torch Dataset; the result is the same.)
for label in os.listdir(root_dir_base):
    # BUG FIX: the original joined against `root_dir_novel`, an undefined
    # name, which raises a NameError; it must be the directory being listed.
    label_dir = os.path.join(root_dir_base, label)
    for image_file in os.listdir(label_dir):
        image_paths.append(os.path.join(label_dir, image_file))
        # "001.ClassName" -> keep only "ClassName".
        labels.append(label.split('.')[-1])

# Answer-format instruction appended to the prompt; fill in your own.
USER_ANS_FORMAT = ''

chat_state = CONV_VISION.copy()
right_count = 0
for i, gr_img in enumerate(image_paths):
    # BUG FIX: the original read the stale `image_path` variable left over
    # from the collection loop above (always the LAST collected image), so
    # every iteration compared against the same ground truth. Derive the
    # label from the image currently being evaluated instead.
    label_num = os.path.basename(os.path.dirname(gr_img))
    label = label_num.split('.')[-1]

    img_list = []
    # Example: user_message = '[grounding] describe this image in detail'
    user_message = '[grounding]' + USER_ANS_FORMAT

    # Upload the image, ask the question, then encode the image features.
    chat.upload_img(gr_img, chat_state, img_list)
    chat.ask(user_message, chat_state)
    chat.encode_img(img_list)

    # Generate the reply; index 0 of the returned tuple is presumably the
    # generated text.
    llm_message = chat.answer(conv=chat_state,
                              img_list=img_list,
                              temperature=1.5,
                              max_new_tokens=1000,
                              max_length=2000)[0]
    print('Ans: ' + llm_message + ', Ground_truth: ' + label_num)

    # Count as correct when the class name appears verbatim in the answer.
    if label in llm_message:
        right_count += 1

    # Reset the conversation history so samples do not leak into each other.
    chat_state.messages = []

# Guard against an empty dataset to avoid ZeroDivisionError.
total = len(image_paths)
if total:
    print('Total: %d, Right_count: %d, Acc: %.3f' % (total, right_count, right_count / total))
else:
    print('No images found under root_dir_base')