Environment setup
1. Create the environment and install dependencies
git clone https://github.com/Vision-CAIR/MiniGPT-4.git
cd MiniGPT-4
conda env create -f environment.yml
conda activate minigpt4
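Once the environment is active, a quick import check can confirm that the key packages were installed. The following is a minimal sketch, not part of the MiniGPT-4 repository; the helper name missing_packages and the package names in the usage comment are assumptions chosen for illustration.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Example usage (package names are assumptions; adjust to your environment.yml):
# print(missing_packages(["torch", "transformers"]))
```

An empty list means every listed package is importable from the active environment.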
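2. Configure the Vicuna weights
Prepare the Vicuna weights by following the Vicuna weight setup procedure, then point MiniGPT-4 at them: in minigpt4/configs/models/minigpt4.yaml, change the path on line 16 to the folder that contains the Vicuna weights.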
llama_model: "/path/to/vicuna/weights/"
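A wrong path here usually only surfaces later, when the model is loaded, so it can help to check it right after editing. The sketch below is one possible way to do that using only the standard library; the function names are hypothetical, and the line-based parsing assumes a simple `llama_model: "..."` entry with no inline comment.

```python
import os


def read_llama_model_path(yaml_text):
    """Return the value of the first `llama_model:` entry, or None if there is none."""
    for line in yaml_text.splitlines():
        stripped = line.strip()
        if stripped.startswith("llama_model:"):
            value = stripped.split(":", 1)[1].strip()
            return value.strip("\"'")  # drop surrounding quotes, if any
    return None


def check_weights_dir(config_path):
    """Report whether the configured llama_model path is an existing directory."""
    with open(config_path, encoding="utf-8") as f:
        path = read_llama_model_path(f.read())
    if path is None:
        return "no llama_model entry found"
    if not os.path.isdir(path):
        return f"llama_model points to a missing directory: {path}"
    return "ok"


# Example usage:
# print(check_weights_dir("minigpt4/configs/models/minigpt4.yaml"))
```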
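3. Prepare the MiniGPT-4 checkpoint
Download the pretrained checkpoint by following the instructions at https://github.com/Vision-CAIR/MiniGPT-4/tree/main, then update the configuration file so that it points to the downloaded checkpoint.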
Chatting about an image
Core code
1. Load the model
from minigpt4.conversation.conversation import Chat, CONV_VISION
import argparse
from minigpt4.common.config import Config
from minigpt4.common.registry import registry
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", default="eval_configs/minigpt4_eval.yaml", help="path to configuration file.")
parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
parser.add_argument(
    "--options",
    nargs="+",
    help="override some settings in the used config, the key-value pair "
    "in xxx=yyy format will be merged into config file (deprecate), "
    "change to --cfg-options instead.",
)
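# Parse an empty argument list so that the defaults above are used
# (handy when running inside a notebook rather than from the command line).
args = parser.parse_args([])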
print(args)
print('Initializing Chat')
cfg = Config(args)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
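The --options help text above mentions that key-value pairs in xxx=yyy format are merged into the config. In this walkthrough no options are passed, so args.options is None. To illustrate the idea only, here is a minimal sketch of merging dotted key=value strings into a nested dictionary. It is not the repository's implementation, the function name apply_overrides is hypothetical, and the keys in the usage comment are made up for the example.

```python
import copy


def apply_overrides(config, overrides):
    """Merge dotted `key=value` strings into a copy of a nested dict.

    Values are kept as strings; intermediate dicts are created when missing.
    """
    merged = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = merged
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return merged


# Example usage (hypothetical keys):
# apply_overrides({"model": {"arch": "mini_gpt4"}}, ["model.some_key=some_value"])
```

The input dictionary is left unchanged, so the original config stays available for comparison.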
2. Convert the image into embeddings
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img("/home/xs/图片/截图/cat.jpeg", chat_state, img_list)
After the upload, img_list holds the image embedding:
[tensor([[[-7.0020e-01, -2.2832e+00, -2.7754e+00, ..., 1.5625e+00,
-1.1436e+00, 8.3496e-01],
[ 3.0430e+00, 1.1338e+00, -2.1106e-01, ..., 1.9983e-01,
-5.9375e+00, -6.4258e-01],
[ 3.1113e+00, -4.3320e+00, -2.0950e-02, ..., -1.3037e+00,
-9.2578e-01, -9.8145e-01],
...,
[-5.4375e+00, 1.7754e+00, 5.2383e+00, ..., 1.7041e+00,
9.7266e+00, 6.5088e-01],
[ 1.7158e+00, 2.9531e+00, 2.4551e+00, ..., 3.8071e-03,
-2.0273e+00, 3.0225e-01],
[ 6.7344e+00, -3.2446e-01, -4.3555e-01, ..., 2.2285e+00,
3.5000e+00, -1.6357e+00]]], device='cuda:0', dtype=torch.float16,
grad_fn=<ViewBackward0>)]
3. Ask a question and chat with the model
# The prompt means "What does this image show?"
chat.ask("这张图表示了什么内容", chat_state)
chat.answer(conv=chat_state,
            img_list=img_list,
            num_beams=1,
            temperature=0.8,
            max_new_tokens=300,
            max_length=2000)[0]
'This image shows a black cat with green eyes looking directly at the camera.'
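The three calls used above (upload_img, ask, answer) can be wrapped in a small helper when several images need to be queried. This is a sketch under the assumption that chat is the Chat object created in step 1 and conv_template is CONV_VISION; the helper name ask_about_image is hypothetical, and the default generation parameters simply repeat the values used above.

```python
def ask_about_image(chat, conv_template, image_path, question,
                    num_beams=1, temperature=0.8,
                    max_new_tokens=300, max_length=2000):
    """Upload one image, ask one question, and return the reply text."""
    chat_state = conv_template.copy()  # fresh conversation for each image
    img_list = []                      # upload_img appends the image embedding here
    chat.upload_img(image_path, chat_state, img_list)
    chat.ask(question, chat_state)
    result = chat.answer(conv=chat_state,
                         img_list=img_list,
                         num_beams=num_beams,
                         temperature=temperature,
                         max_new_tokens=max_new_tokens,
                         max_length=max_length)
    return result[0]  # first element is the generated text, as in the call above


# Example usage, with the objects created earlier:
# print(ask_about_image(chat, CONV_VISION, "/path/to/cat.jpeg", "What does this image show?"))
```

Starting from a fresh copy of the conversation template keeps earlier images and questions from leaking into the next query.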