visual-genome rcnn features 提取(二)- 提取篇

在完成了caffe的配置后,以及安装完依赖库cython, opencv, pyyaml, easydict

这里首先记录一下easydict的错误(再次强调,这个库一定要装低版本!),在通过pip install easydict后可解决

File "./tools/generate_tsv.py", line 221, in <module> assert cfg.TEST.HAS_RPN
Assertion error: cfg.TEST.HAS_RPN == False

下面进入正题,genome features的提取

  1. 下载好visual-genome的images,分为两组VG_100K与VG_100K_2,两组的image_id无重复,全部放入bottom-up-attention/data/VG_data目录下(这里附上数据链接http://visualgenome.org/
     
  2. 对tools/generate_tsv稍作修改,主要更改58-65行的directory即可,这里我的更改如下:
    with open('./data/visualgenome/image_data.json') as f:
        for item in json.load(f):
            image_id = int(item['image_id'])
    		# filepath = os.path.join('./data/VGdata/', item['url'].split('rak248/')[-1])
    		filepath = os.path.join('./data/VGdata/', str(image_id)+'.jpg') # 这里可直接用作者的那句话
    		# print(filepath, os.path.exits(filepath))
    		split.append((filepath,image_id))

     

  3. 按照作者给的例子执行代码即可,超参数给出如下,采用作者给出的pretrained_model:
    python ./tools/generate_tsv.py --gpu 0 --cfg experiments/cfgs/faster_rcnn_end2end_resnet.yml --def ./models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt --out /home/share/bierone/genome_resnet101_faster_rcnn_genome.tsv --net data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel --split genome

     

验证:提取完tsv文件后,难免需要进行检查,验证box的位置是否合理,这里附上本人的代码show.py(难免有疏漏之处,希望大家不吝指出):

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# set display defaults
# plt.rcParams['figure.figsize'] = (10, 10)		   # large images
# plt.rcParams['image.interpolation'] = 'nearest'  # don't interpolate: show square pixels
# plt.rcParams['image.cmap'] = 'gray'  # use grayscale output rather than a (potentially misleading) color heatmap
import numpy as np
import cv2, base64
import csv, sys
csv.field_size_limit(sys.maxsize)

FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features']
infile = '/home/share/lyb/genome_resnet101_faster_rcnn_genome.tsv'
data_root = '/home/lyb/bottom-up-attention/data/VGdata/'


def get_detections_from_tsv(nums=5):
	in_data = {}
	with open(infile, "r") as tsv_in_file:
		reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES)
		for i, item in enumerate(reader):
			item['image_id'] = (item['image_id'])
			item['image_h'] = int(item['image_h'])
			item['image_w'] = int(item['image_w'])	 
			item['num_boxes'] = int(item['num_boxes'])
			for field in ['boxes', 'features']:
				item[field] = np.frombuffer(base64.decodestring(item[field].encode('utf8')), 
					  dtype=np.float32).reshape((item['num_boxes'],-1))
			# show_features(item['boxes'])
			in_data[i] = item
			if i > nums:
				break
	return in_data

def show_features(ax, boxes, objects='aa', attrs='bb'):
	for i in range(boxes.shape[0]):
		bbox = boxes[i]
		ax.add_patch(
			plt.Rectangle((bbox[0], bbox[1]),
						  bbox[2] - bbox[0],
						  bbox[3] - bbox[1], fill=False,
						  edgecolor='red', linewidth=2, alpha=0.8)
				)
		# plt.gca().text(bbox[0], bbox[1] - 2,
					# '%s' % (cls),
					# bbox=dict(facecolor='blue', alpha=0.5),
					# fontsize=12, color='white')
		plt.axis('off')
		# plt.tight_layout()
		plt.draw()


if __name__ == '__main__':
	in_data = get_detections_from_tsv()
	for key,item in in_data.items():
		# print(item)
		im_file = data_root + item['image_id'] + '.jpg'
		im = cv2.imread(im_file)
		# im = im[:, :, (2, 1, 0)] # RGB reverse channels
		fig, ax = plt.subplots(figsize=(20, 20))
		ax.imshow(im)
		show_features(ax,item['boxes'])
		# rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
		plt.savefig('demo/'+item['image_id']+'.jpg')

 

总结:这里作者的代码写的比较复杂,我只针对部分做了仔细查看,就不附上分析了。其实整个提取过程并不复杂,时间花费主要在配置环境上。

 

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 17
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值