reid 贴图效果
代码功能:
对reid图片进行分割,并粘贴到目标检测图像中,算法可实现根据检测图像真值,选择合理数量的reid图片,进行算法缩放后进行粘贴
A. function statement:
1、iou_compute_getcenter(): returns a single center coordinate (x, y), or None if no valid center is found after 10 tries
2、reid_img_select(): reads /home/cjs/reid_img_root.txt and returns the requested number of reid image paths
3、seamlessclone(): uses the mask segmented by BASNet to paste the reid images onto detection images
4、class ReidAttach:
- annotation_update(): update annotation of reid info. into gt of detection annotation
- reid_attach(): attach all the selected reid images into one detection image in 80000 images
- main(): main function
- forward(): sends all 80000 images into main() one by one
B. root:(file path and root can be changed according to your mission)
- --reidfile_root: '/home/cjs/reid_img_root.txt' 所有reid 图片路径的集合 .txt文件
- --save_root: '/home/cjs/attached_images/' reid贴图后的保存位置
- --cache_path: '/home/cjs/data/100train.cache' 所有detection images的路径集合 .cache文件
- --segmodel_root: "/home/cjs/SmallObjectDetection/BASNet/saved_models/basnet_bsi/basnet.pth" 分割算法的weights .pth文件
C. python code
# Standard imports
import cv2
import torch
import numpy as np
import math
from BASNet import basnet_test as bas
from BASNet.model.BASNet import BASNet
import argparse
import pickle
import copy
def centerhw2xymaxmin(centerx, centery, h, w):
    """Convert a box given as center + size into its corner extents.

    :param centerx: x coordinate of the box center
    :param centery: y coordinate of the box center
    :param h: box height
    :param w: box width
    :return: tuple ``(xmin, xmax, ymin, ymax)``
    """
    half_w = w / 2
    half_h = h / 2
    return (
        centerx - half_w,
        centerx + half_w,
        centery - half_h,
        centery + half_h,
    )
def iou_compute_getcenter(gt_datas, dst_size, reid_size, scale, max_tries=10):
    """Pick a random center where a scaled reid crop fits without overlapping any box.

    The center is drawn only from positions where the scaled crop lies fully
    inside the detection image, so a crop that cannot fit returns ``None``
    immediately instead of retrying forever.

    :param gt_datas: list of annotation objects of the detect img; each must
        expose ``x_top_left``, ``y_top_left``, ``width`` and ``height``
    :param dst_size: detect img size, indexed as ``(height, width, ...)``
    :param reid_size: reid img size, indexed as ``(height, width, ...)``
    :param scale: scale factor applied to the reid height and width
    :param max_tries: number of random centers to try before giving up
    :return: ``(center_x, center_y)`` as ints, or ``None`` if the crop does not
        fit in the image or no overlap-free center was found in ``max_tries``
    """
    dst_h, dst_w = dst_size[0], dst_size[1]
    half_h = reid_size[0] * scale / 2
    half_w = reid_size[1] * scale / 2

    # Integer centers for which the whole box stays inside the image:
    # center - half >= 0 and center + half <= image extent.
    x_lo = max(math.ceil(half_w), 0)
    x_hi = min(math.floor(dst_w - half_w), dst_w - 1)
    y_lo = max(math.ceil(half_h), 0)
    y_hi = min(math.floor(dst_h - half_h), dst_h - 1)
    if x_lo > x_hi or y_lo > y_hi:
        # the scaled reid image is larger than the detect image
        return None

    for _ in range(max_tries):
        center_x = int(np.random.randint(x_lo, x_hi + 1))
        center_y = int(np.random.randint(y_lo, y_hi + 1))
        xminreid, xmaxreid = center_x - half_w, center_x + half_w
        yminreid, ymaxreid = center_y - half_h, center_y + half_h

        # Two boxes overlap when their intersection has positive width AND
        # positive height; touching edges do not count as overlap.
        overlaps = any(
            max(gt.x_top_left, xminreid) < min(gt.x_top_left + gt.width, xmaxreid)
            and max(gt.y_top_left, yminreid) < min(gt.y_top_left + gt.height, ymaxreid)
            for gt in gt_datas
        )
        if not overlaps:
            return center_x, center_y  # one reid image one center
    return None
def reid_img_select(num, reidfile_root):
    """Randomly draw ``num`` reid image paths (with replacement) from a list file.

    :param num: number of paths to draw
    :param reidfile_root: text file containing one reid image path per line
    :return: list of ``num`` paths; the same path may appear more than once
    :raises ValueError: if ``num`` > 0 but the file contains no paths
    """
    with open(reidfile_root, 'r') as list_file:
        # Strip only the line terminator so a final line without a trailing
        # newline keeps its last character; blank lines are not valid paths.
        candidates = [line.rstrip('\r\n') for line in list_file]
    candidates = [path for path in candidates if path]

    if num <= 0:
        return []
    if not candidates:
        raise ValueError('no reid image paths found in ' + str(reidfile_root))

    total = len(candidates)
    return [candidates[np.random.randint(total)] for _ in range(num)]
def seamlessclone(dst, reidimg, src_mask, center, scale):
    """Paste the masked foreground of a reid crop into ``dst`` around ``center``.

    Both the crop and its mask are resized by ``scale``; wherever the resized
    mask is non-zero the destination pixels are replaced by the crop pixels.
    ``dst`` is modified in place and also returned.

    :param dst: detection image (assumed to be an H x W x C array as returned
        by ``cv2.imread`` -- TODO confirm)
    :param reidimg: reid crop to paste
    :param src_mask: segmentation mask of the crop; non-zero marks foreground
    :param center: ``(x, y)`` center of the pasted crop inside ``dst``
    :param scale: resize factor applied to both crop and mask
    :return: ``dst`` with the crop pasted in
    """
    def _rescale(image):
        # cv2.resize takes the target size as (width, height)
        target = (int(scale * image.shape[1]), int(scale * image.shape[0]))
        return cv2.resize(image, target, interpolation=cv2.INTER_CUBIC)

    patch = _rescale(reidimg)
    mask = _rescale(src_mask)
    mask_h, mask_w = mask.shape[0], mask.shape[1]

    # Window of dst that the resized crop will occupy.
    rows = slice(int(center[1] - mask_h / 2), int(center[1] + mask_h / 2))
    cols = slice(int(center[0] - mask_w / 2), int(center[0] + mask_w / 2))
    window = dst[rows, cols, :]

    # Clear foreground pixels in the window and background pixels in the crop,
    # so the in-place sum below combines two disjoint pixel sets.
    window[mask != 0] = 0
    patch[mask == 0] = 0
    window += patch
    return dst
class ReidAttach:
    """Paste segmented reid crops into detection images and record them as annotations.

    ``cache_dict`` maps each detection image path to its list of annotation
    objects; pasted reid crops are appended to those lists.
    """

    def __init__(self, opt):
        """Load the annotation cache and the BASNet segmentation weights.

        :param opt: parsed options providing ``cache_path`` and ``segmodel_root``
        """
        self.cache_dict = {}
        # Template annotation that annotation_update() clones; stays None when
        # the cache holds no annotation at all.
        self.sample = None
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # cache files from a trusted source.
        with open(opt.cache_path, 'rb') as cache_file:
            train_line = pickle.load(cache_file)
        for key in train_line.keys():
            annotations = train_line[key]
            self.cache_dict[key] = copy.deepcopy(annotations)
            # Take the template from the first image that has annotations,
            # not only from the very first key.
            if self.sample is None and len(annotations) != 0:
                self.sample = annotations[0]

        print('---BASNet loading---')
        self.model = BASNet(3, 1)
        self.model.load_state_dict(torch.load(opt.segmodel_root))
        print('--Model load finished--')

    def annotation_update(self, dst_img_root, reid_info):
        """Append an annotation for one pasted reid crop to a detection image.

        :param dst_img_root: detection image path (key of ``cache_dict``)
        :param reid_info: ``[reid_img, center, scale, reid_id, reid_cls]``
        :raises RuntimeError: if no template annotation is available
        """
        if self.sample is None:
            raise RuntimeError(
                'no template annotation available: the cache contains no annotations')
        reid_img, center, scale, reid_id, reid_cls = reid_info[:5]

        temp = copy.deepcopy(self.sample)
        temp.class_label = reid_cls
        temp.object_id = reid_id
        # Height and width both follow the scale used when pasting the crop.
        temp.height = int(reid_img.shape[0] * scale)
        temp.width = int(reid_img.shape[1] * scale)
        temp.x_top_left = int(center[0] - temp.width / 2)
        temp.y_top_left = int(center[1] - temp.height / 2)
        self.cache_dict[dst_img_root].append(temp)

    def reid_attach(self, opt, dst_img_root, reid_infos):
        """Paste every selected reid crop into one detection image and save it.

        :param opt: options providing ``save_root``, the output path prefix
        :param dst_img_root: destination image path; it must contain a
            ``JPEGImages`` component, and the component after it is used as
            the output file name
        :param reid_infos: dict mapping reid image path to
            ``[reid_img, center, scale, reid_id, reid_cls]``
        :return: the attached image
        """
        print('start inferring image: ' + dst_img_root + ' and updating the annotation file.')
        dst = cv2.imread(dst_img_root)
        for reid_path, reid_info in reid_infos.items():
            mask = bas.mask(reid_path, self.model)
            dst = seamlessclone(dst, reid_info[0], mask, reid_info[1], reid_info[2])

        # save attached image
        path_parts = dst_img_root.split('/')
        save_root = opt.save_root + path_parts[path_parts.index('JPEGImages') + 1]
        if not cv2.imwrite(save_root, dst):
            # cv2.imwrite reports failure through its return value only
            print('warning: failed to write attached image: ' + save_root)
        return dst

    def main(self, opt, gt_datas, dst_img_root):
        """Choose, place and paste reid crops for a single detection image.

        The number of crops depends on how much of the image the existing
        boxes cover; the crop scale is derived from the average box area.

        :param opt: options providing ``reidfile_root`` and ``save_root``
        :param gt_datas: annotations of the detection image
        :param dst_img_root: detection image path (key of ``cache_dict``)
        :return: the attached image
        :raises OSError: if the detection image cannot be read
        """
        dst_img = cv2.imread(dst_img_root)
        if dst_img is None:
            raise OSError('could not read detection image: ' + str(dst_img_root))
        dst_size = dst_img.shape
        area_dst_total = dst_size[0] * dst_size[1]

        # Snapshot the statistics before any reid annotation is appended.
        num_gt = len(gt_datas)
        area_gt_total = sum(gt.width * gt.height for gt in gt_datas)

        if num_gt:
            # area of gt over area of img, decides how many reid images to attach
            iouovertotal = area_gt_total / area_dst_total
            # average area of ground truth, used to scale the reid image
            gt_ave_area = area_gt_total / num_gt
            h_est = math.sqrt(gt_ave_area)
        else:
            iouovertotal = 0.0001
            h_est = None

        # number of reid img in one detect pic: fewer when the image is crowded
        if iouovertotal > 0.5:
            reid_num = np.random.randint(5, 10)
        else:
            reid_num = np.random.randint(10, 20)

        centers = []
        reid_infos = {}
        # This list also receives the annotations appended below, so later
        # crops avoid the ones already placed.
        gt_datas_update = self.cache_dict[dst_img_root]

        for _ in range(2):  # at most two sampling rounds
            if len(centers) >= reid_num:
                break
            for reid_path in reid_img_select(reid_num, opt.reidfile_root):
                if len(centers) >= reid_num:
                    break
                if reid_path in reid_infos:
                    # reid_infos is keyed by path: pasting the same path twice
                    # would leave an annotation without a pasted object.
                    continue
                reid_img = cv2.imread(reid_path)
                if reid_img is None:
                    print('warning: could not read reid image: ' + str(reid_path))
                    continue

                if num_gt == 0:
                    scale = 1
                else:
                    # NOTE(review): square root of the height ratio, not the
                    # ratio itself -- confirm this is intended.
                    scale = math.sqrt(h_est / reid_img.shape[0])

                center = iou_compute_getcenter(gt_datas_update, dst_size, reid_img.shape, scale)
                if center is None:
                    continue
                centers.append(center)

                # the parent directory name is the identity id
                reid_id = int(reid_path.split('/')[-2])
                reid_cls = 'car' if 'zcyDataset30w' in reid_path else 'pedestrian'
                reid_info = [reid_img, center, scale, reid_id, reid_cls]
                reid_infos[reid_path] = reid_info
                # update the annotation cache
                self.annotation_update(dst_img_root, reid_info)

        img_attached = self.reid_attach(opt, dst_img_root, reid_infos)
        return img_attached

    def forward(self, opt):
        """Run :meth:`main` on every detection image in the cache.

        :param opt: parsed command-line options forwarded to :meth:`main`
        """
        for key in self.cache_dict:
            self.main(opt, self.cache_dict[key], key)
if __name__ == '__main__':
    # Command-line flags with their default locations; every path can be
    # overridden per run.
    cli_defaults = (
        ('--reidfile_root', '/home/cjs/reid_img_root.txt'),
        ('--save_root', '/home/cjs/attached_images/'),
        ('--cache_path', '/home/cjs/data/100train.cache'),
        ('--segmodel_root', "/home/cjs/SmallObjectDetection/BASNet/saved_models/basnet_bsi/basnet.pth"),
    )
    parser = argparse.ArgumentParser(description='Detect attached reid image for data augmentation')
    for flag, default_path in cli_defaults:
        parser.add_argument(flag, default=default_path)
    opt = parser.parse_args()

    # Build the attacher (loads the cache and the model) and process every image.
    attacher = ReidAttach(opt)
    attacher.forward(opt)