BLIP2复现
直接给出我可以运行的代码:
import torch
from datasets import load_dataset
from lavis.models import load_model_and_preprocess
from matplotlib import pyplot as plt
from PIL import Image
import requests
from torch.cuda.amp import autocast
# Select the compute device: CUDA when available, otherwise the CPU.
# Both branches go through torch.device so `device` has the same type on
# either path (the original produced a bare "cpu" string on the fallback).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained BLIP-2 model together with its image and text
# preprocessors; is_eval=True asks LAVIS for the model in evaluation mode.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip2_feature_extractor",
    model_type="pretrain_vitL",
    is_eval=True,
    device=device,
)
# Download the sample image and decode it as RGB.
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'
# The timeout keeps the script from hanging forever on a stalled connection,
# and raise_for_status() reports HTTP errors directly instead of letting PIL
# fail later while trying to decode an error page.
with requests.get(img_url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() reads the pixel data while the connection is still open.
    raw_image = Image.open(response.raw).convert('RGB')

# Show the original image, resized to 596x437 for display, without axes.
plt.imshow(raw_image.resize((596, 437)))
plt.axis('off')
plt.show()
# Preprocess the image with the "eval" visual processor, add a leading
# dimension (presumably the batch axis) and move the tensor to the device.
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

# Run generation under float16 autocast. The body must be indented under the
# `with` statement (the original was not, which is an IndentationError).
# Autocast is only enabled when CUDA is actually available, so a CPU-only run
# stays in the model's native precision.
with autocast(enabled=torch.cuda.is_available(), dtype=torch.float16):
    result = model.generate({"image": image})

# Print the generated output.
print(result)
reference:BLIP2