2020-10-27 reid attach mission贴图任务

reid 贴图效果

代码功能:

对reid图片进行分割,并粘贴到目标检测图像中,算法可实现根据检测图像真值,选择合理数量的reid图片,进行算法缩放后进行粘贴

A. Function overview:

1、iou_compute_getcenter(): returns a single center coordinate (x, y), or None if no valid center is found after 10 attempts

2、reid_img_select(): randomly picks the required number of reid image paths from /home/cjs/reid_img_root.txt

3、seamlessclone(): uses the mask segmented by BASNet to paste the reid images onto the detection images

4、class ReidAttach:

  • annotation_update(): adds the info of each pasted reid image to the ground-truth annotations of the detection image
  • reid_attach(): pastes all the selected reid images onto one detection image (one of the 80000 images)
  • main(): main function
  • forward(): feeds all 80000 detection images into main() one by one

B. Paths (file paths can be changed according to your task):

  1. --reidfile_root:  '/home/cjs/reid_img_root.txt'           所有reid 图片路径的集合 .txt文件
  2. --save_root: '/home/cjs/attached_images/'              reid贴图后的保存位置
  3. --cache_path: '/home/cjs/data/100train.cache'        所有detection images的路径集合 .cache文件
  4. --segmodel_root: "/home/cjs/SmallObjectDetection/BASNet/saved_models/basnet_bsi/basnet.pth"              分割算法的weights .pth文件

 

C. python code

# Standard imports
import cv2
import torch
import numpy as np
import math
from BASNet import basnet_test as bas
from BASNet.model.BASNet import BASNet
import argparse
import pickle
import copy


def centerhw2xymaxmin(centerx, centery, h, w):
    """Convert a box given by its center and size into corner coordinates.

    :param centerx: x coordinate of the box center
    :param centery: y coordinate of the box center
    :param h: box height
    :param w: box width
    :return: tuple ``(xmin, xmax, ymin, ymax)``
    """
    half_w = w / 2
    half_h = h / 2
    return (centerx - half_w, centerx + half_w,
            centery - half_h, centery + half_h)


def iou_compute_getcenter(gt_datas, dst_size, reid_size, scale, max_tries=10):
    '''
    Pick a random center for a scaled reid image inside a detection image.

    A candidate center is accepted when the scaled reid box lies fully inside
    the detection image and does not overlap any box in ``gt_datas``.

    :param gt_datas: ground-truth boxes of the detection image; each item must
        expose ``x_top_left``, ``y_top_left``, ``width`` and ``height``
    :param dst_size: detection image shape, indexed as (height, width, ...)
    :param reid_size: reid image shape, indexed as (height, width, ...)
    :param scale: factor applied to the reid image height and width
    :param max_tries: number of random candidates to draw before giving up
    :return: ``(center_x, center_y)`` as ints, or None if no candidate fits
    '''
    dst_h, dst_w = dst_size[0], dst_size[1]
    # Half extents of the scaled reid box (same box convention as
    # centerhw2xymaxmin: center +/- size/2).
    half_h = reid_size[0] * scale / 2
    half_w = reid_size[1] * scale / 2

    # A patch larger than the image can never fit; bail out immediately.
    if 2 * half_w > dst_w or 2 * half_h > dst_h:
        return None

    for _ in range(max_tries):
        # Every candidate counts as an attempt, including ones rejected for
        # crossing the image border; otherwise this could loop forever.
        center_x = int(np.random.random() * dst_w)
        center_y = int(np.random.random() * dst_h)
        xminreid, xmaxreid = center_x - half_w, center_x + half_w
        yminreid, ymaxreid = center_y - half_h, center_y + half_h

        # reject boxes that stick out of the detection image
        if xminreid < 0 or yminreid < 0 or xmaxreid > dst_w or ymaxreid > dst_h:
            continue

        for gt in gt_datas:
            xmin, ymin = gt.x_top_left, gt.y_top_left
            xmax, ymax = xmin + gt.width, ymin + gt.height
            # Positive-area intersection means overlap; touching edges is allowed.
            if (max(xmin, xminreid) < min(xmax, xmaxreid)
                    and max(ymin, yminreid) < min(ymax, ymaxreid)):
                break  # overlaps this gt, draw a new candidate
        else:
            # no gt overlapped (also covers an empty gt_datas)
            return center_x, center_y
    return None


def reid_img_select(num, reidfile_root):
    """Randomly pick ``num`` reid image paths (with replacement) from a list file.

    :param num: number of paths to return
    :param reidfile_root: text file with one reid image path per line
    :return: list of ``num`` paths with line endings removed
    :raises ValueError: if ``num > 0`` but the file contains no paths
    """
    with open(reidfile_root, 'r') as f:
        # rstrip instead of line[:-1]: the last line may lack a trailing
        # newline, and its final character must not be chopped off.
        paths = [line.rstrip('\r\n') for line in f]
    # blank lines would otherwise become empty (unreadable) image paths
    paths = [p for p in paths if p]
    if num > 0 and not paths:
        raise ValueError('no reid image paths found in %s' % reidfile_root)

    num_lines = len(paths)
    return [paths[np.random.randint(num_lines)] for _ in range(num)]


def seamlessclone(dst, reidimg, src_mask, center, scale):
    """Paste the masked foreground of a reid image onto a detection image.

    Despite its name this is a direct pixel copy, not OpenCV Poisson blending
    (``cv2.seamlessClone``). Where the resized mask is non-zero, pixels come
    from the resized reid image; everywhere else ``dst`` is left unchanged.

    :param dst: detection image indexed as (row, col, channel); modified in place
    :param reidimg: reid image to paste (not modified; a resized copy is used)
    :param src_mask: segmentation mask of ``reidimg``; non-zero marks foreground
    :param center: (x, y) center of the pasted region in ``dst``; the caller must
        choose it so the scaled patch lies inside ``dst``
    :param scale: resize factor applied to the reid image
    :return: ``dst`` (the same array) with the reid foreground pasted in
    """
    new_w = int(scale * reidimg.shape[1])
    new_h = int(scale * reidimg.shape[0])
    reidimg = cv2.resize(reidimg, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
    # Resize the mask to exactly the resized reid size so the two stay aligned
    # even if the mask's resolution differs from the reid image's.
    # Nearest-neighbour keeps the mask's values instead of smearing them
    # across the foreground boundary.
    src_mask = cv2.resize(src_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # The top-left corner uses int() truncation. The slice has a fixed size so
    # the region always matches the (new_h, new_w) patch.
    x0 = int(center[0] - new_w / 2)
    y0 = int(center[1] - new_h / 2)
    dst_region = dst[y0:y0 + new_h, x0:x0 + new_w]  # view: writes land in dst

    foreground = src_mask != 0
    dst_region[foreground] = reidimg[foreground]
    return dst


class ReidAttach:
    """Paste random reid images into detection images and record them as gts.

    ``self.cache_dict`` maps each detection image path to its list of
    ground-truth objects, as loaded from the pickled ``.cache`` file at
    ``opt.cache_path``. Gt objects are read through ``x_top_left``,
    ``y_top_left``, ``width`` and ``height``. New entries also set
    ``class_label`` and ``object_id``.
    """

    def __init__(self, opt):
        self.cache_dict = {}
        # Template gt object; deep-copied by annotation_update for new entries.
        self.sample = None

        # NOTE: pickle.load can run arbitrary code -- only load trusted .cache files.
        with open(opt.cache_path, 'rb') as f1:
            train_line = pickle.load(f1)
        for key in train_line.keys():
            gts = train_line[key]
            # The freshly unpickled lists are referenced nowhere else, so they
            # are stored directly instead of deep-copied.
            self.cache_dict[key] = gts
            # Take the first gt found in any image as the template, not only
            # from the first image (which may have no gts).
            if self.sample is None and len(gts) != 0:
                self.sample = gts[0]

        print('---BASNet loading---')
        self.model = BASNet(3, 1)
        # map_location='cpu' lets weights saved from a GPU load on a CPU-only
        # machine; load_state_dict copies them into this (CPU) model anyway.
        self.model.load_state_dict(torch.load(opt.segmodel_root, map_location='cpu'))
        print('--Model load finished--')

    def annotation_update(self, dst_img_root, reid_info):
        """Append a gt entry for a pasted reid image to ``cache_dict[dst_img_root]``.

        :param dst_img_root: detection image path (key of ``cache_dict``)
        :param reid_info: [reid_img, center, scale, reid_id, reid_cls]
        :raises RuntimeError: if the cache contained no gt to use as a template
        """
        if self.sample is None:
            raise RuntimeError('no ground-truth object found in the cache to use '
                               'as an annotation template')
        temp = copy.deepcopy(self.sample)
        temp.class_label = reid_info[4]
        temp.object_id = reid_info[3]
        # Box size is int(original size * scale) in both dimensions, the same
        # size seamlessclone resizes the pasted patch to.
        temp.height = int(reid_info[0].shape[0] * reid_info[2])
        temp.width = int(reid_info[0].shape[1] * reid_info[2])
        temp.x_top_left = int(reid_info[1][0] - temp.width/2)
        temp.y_top_left = int(reid_info[1][1] - temp.height/2)

        self.cache_dict[dst_img_root].append(temp)

    def reid_attach(self, opt, dst_img_root, reid_infos):
        """
        Paste every selected reid image onto the detection image and save it.

        :param opt: options; ``opt.save_root`` is the output path prefix
        :param dst_img_root: detection image path
        :param reid_infos: dict mapping reid image path to
            [reid_img, center, scale, reid_id, reid_cls]
        :return: the detection image with all reid images pasted in
        :raises FileNotFoundError: if the detection image cannot be read
        """
        print('start inferring image: ' + dst_img_root + ' and updating the annotation file.')
        dst = cv2.imread(dst_img_root)
        if dst is None:
            raise FileNotFoundError('cannot read detection image: ' + dst_img_root)
        for key in reid_infos:
            reid_info = reid_infos[key]
            # presumably BASNet's foreground mask for the reid image at path `key`
            mask = bas.mask(key, self.model)
            dst = seamlessclone(dst, reid_info[0], mask, reid_info[1], reid_info[2])

        # Name the output after the path component following 'JPEGImages';
        # fall back to the file name when that folder is not in the path.
        parts = dst_img_root.split('/')
        try:
            file_name = parts[parts.index('JPEGImages') + 1]
        except (ValueError, IndexError):
            file_name = parts[-1]
        cv2.imwrite(opt.save_root + file_name, dst)

        return dst

    def main(self, opt, gt_datas, dst_img_root):
        """
        Choose, place, annotate and paste reid images for one detection image.

        :param opt: options (``reidfile_root`` here; passed on to reid_attach)
        :param gt_datas: gt objects of the detection image
        :param dst_img_root: detection image path (key of ``cache_dict``)
        :return: the image returned by reid_attach
        :raises FileNotFoundError: if the detection image cannot be read
        """
        dst_img = cv2.imread(dst_img_root)
        if dst_img is None:
            raise FileNotFoundError('cannot read detection image: ' + dst_img_root)
        dst_size = dst_img.shape
        area_dst_total = dst_size[0] * dst_size[1]

        num_gt = len(gt_datas)
        area_gt_total = sum(gt.width * gt.height for gt in gt_datas)
        if num_gt:
            # fraction of the image covered by gts: decides how many reid images to add
            iouovertotal = area_gt_total / area_dst_total
            # side of a square with the average gt area: used to scale reid images
            h_est = math.sqrt(area_gt_total / num_gt)
        else:
            iouovertotal = 0.0001
            h_est = None

        if iouovertotal > 0.5:
            reid_num = np.random.randint(5, 10)   # number of reid imgs in one detect pic
        else:
            reid_num = np.random.randint(10, 20)  # number of reid imgs in one detect pic

        centers = []
        reid_infos = {}
        for _ in range(2):  # at most two rounds of random selection
            if len(centers) >= reid_num:
                break
            reidimg_roots_list = reid_img_select(reid_num, opt.reidfile_root)
            for reid_root in reidimg_roots_list:
                if len(centers) >= reid_num:
                    break
                if reid_root in reid_infos:
                    # reid_infos is keyed by path: a repeat would overwrite the
                    # earlier entry after its annotation was already recorded.
                    continue
                reid_img = cv2.imread(reid_root)
                if reid_img is None:
                    continue  # missing/unreadable reid file: skip it
                scale = 1 if h_est is None else math.sqrt(h_est / reid_img.shape[0])

                # cache_dict[dst_img_root] also receives the annotations added
                # below, so later reid boxes avoid earlier ones as well.
                center = iou_compute_getcenter(
                    self.cache_dict[dst_img_root], dst_size, reid_img.shape, scale)
                if center is None:
                    continue
                centers.append(center)

                # reid id is taken from the parent folder name of the image path
                reid_id = int(reid_root.split('/')[-2])
                reid_cls = 'car' if 'zcyDataset30w' in reid_root else 'pedestrian'
                reid_info = [reid_img, center, scale, reid_id, reid_cls]
                reid_infos[reid_root] = reid_info
                self.annotation_update(dst_img_root, reid_info)

        return self.reid_attach(opt, dst_img_root, reid_infos)

    def forward(self, opt):
        """Run main() on every detection image listed in the cache.

        NOTE(review): annotations added to ``cache_dict`` stay in memory only;
        nothing in this class writes them back to disk.
        """
        for key in list(self.cache_dict):
            self.main(opt, self.cache_dict[key], key)


if __name__ == '__main__':
    # Command-line options; every path below can be overridden on the command line.
    arg_parser = argparse.ArgumentParser(description='Detect attached reid image for data augmentation')
    default_paths = (
        ('--reidfile_root', '/home/cjs/reid_img_root.txt'),   # list of reid image paths
        ('--save_root', '/home/cjs/attached_images/'),       # output folder for pasted images
        ('--cache_path', '/home/cjs/data/100train.cache'),   # pickled detection gt cache
        ('--segmodel_root', "/home/cjs/SmallObjectDetection/BASNet/saved_models/basnet_bsi/basnet.pth"),  # BASNet weights
    )
    for flag, default_value in default_paths:
        arg_parser.add_argument(flag, default=default_value)
    opt = arg_parser.parse_args()

    attacher = ReidAttach(opt)
    attacher.forward(opt)




 

 

 

 

 

 

 

 

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值