1.从github上下载blip源码
2.下载vqav2数据集 https://visualqa.org/download.html
将json文件和训练集、测试集、验证集都下载
(我是把测试集和验证集的图片都复制到训练集train2014里面了,因为程序报错找不到图片)
3.下载Visual Genome数据集 https://homes.cs.washington.edu/~ranjay/visualgenome/api.html
我下的是2016年,分为两个部分。一个9.2G,一个5.47G
(同样的,我也把数据集合二为一,因为程序报错找不到图片)
4.下载复现vqa所需要用到的模型https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth
(也可以在remade里面找一下)
5.修改configs中vqa.yaml的vqa_root路径(此处为vqav2的数据集路径)
6.修改configs中vqa.yaml的vg_root路径(此处为VisualGenome数据集路径)
7.修改train_vqa.py中parser.add_argument(‘–config’, default=‘此处为你自己的vqa.yaml的路径’)
8.修改parser.add_argument(‘–output_dir’, default=‘output/VQA’)
大功告成,直接运行train_vqa.py
以上就可以训练完成,接下来教大家如何进行对话。(也可以看remade中的demo)
先创建一个py文件,将文件放进blip源码下一级,并将下列代码放入py文件中(注意修改我标注的地方)
from models.blip_vqa import blip_vqa
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_demo_image(image_size,device):
img_url = '此处为图片的链接,注意,是链接不是路径'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
w,h = raw_image.size
resized_image = raw_image.resize((w // 5, h // 5))
transform = transforms.Compose([
transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
image = transform(resized_image).unsqueeze(0).to(device)
return image
image_size = 480
image = load_demo_image(image_size=image_size, device=device)
model_url = '这里可以照着demo中改成预训练模型'
model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)
question = '复制图片链接后,这里问你想问的问题,中英文都支持'
with torch.no_grad():
answer = model(image, question, train=False, inference='generate')
print('answer: '+answer[0])