# merge_CocoAndFix3: merge COCO instance masks with fixation maps into saliency-rank label maps.

import os
from pycocotools.coco import COCO
import cv2
import numpy as np

# Thresholds for choosing how many instances receive a saliency rank:
# the ranked instances must together cover at least ``1 - alpha_1`` of the
# image's fixations, and ``alpha_2`` is the minimum relative drop between
# consecutive fixation densities that is treated as a natural cut point.
alpha_1 = 0.25
alpha_2 = 0.65

cocoDir = r'D:\dataset\coco_2014'
dataType = ['train2014', 'val2014']

fixationDir = r'D:\dataset\salient_instance\fixation_maps_org'

outputDir = r'D:\dataset\salient_instance\coco_rank_dataset_2'
# makedirs also creates missing parent directories and does not fail (or race)
# when the directory already exists, unlike an exists-check followed by mkdir.
os.makedirs(outputDir, exist_ok=True)

# One COCO index per split, in the order of dataType (train first, then val).
annFiles = [os.path.join(cocoDir, 'annotations', 'instances_{}.json'.format(dType))
            for dType in dataType]
cocos = [COCO(annFile) for annFile in annFiles]
# Image file names per split, listed in the same order as ``coco.imgs``.
coco_img_names_list = [[img['file_name'] for img in coco.imgs.values()]
                       for coco in cocos]

# Map every fixation-map file to its COCO image id.  File names are expected
# to carry the split as their second '_'-separated token (e.g.
# 'COCO_train2014_000000000009.png'); anything other than 'train2014' is
# looked up in the second (val) index.
fix_lists = os.listdir(fixationDir)
# One file_name -> image-id dict per split, built once.  The old per-file
# ``list.index`` scan over every image name (plus rebuilding the key list)
# made this step O(files x images).
name_to_id_list = [{img['file_name']: img_id for img_id, img in coco.imgs.items()}
                   for coco in cocos]
fix_imgs = dict()
for fix in fix_lists:
    split = fix.split('_')[1]
    name_to_id = name_to_id_list[0] if split == 'train2014' else name_to_id_list[1]
    # The COCO image shares the fixation map's stem but has a .jpg extension.
    fix_img_name = '{}.jpg'.format(os.path.splitext(fix)[0])
    img_idx = name_to_id[fix_img_name]
    fix_imgs[img_idx] = {'type': split, 'img_name': fix_img_name}

def _relative_gap(fix_set, i):
    """Return the relative drop in density from position ``i`` to ``i + 1``.

    ``fix_set`` is sorted in descending order, so for a positive
    ``fix_set[i]`` the result lies in [0, 1].
    """
    return (fix_set[i] - fix_set[i + 1]) / fix_set[i]


def _select_cutoff(fix_set, alpha_1, alpha_2):
    """Choose the index of the last instance that receives a saliency rank.

    Args:
        fix_set: per-instance fixation densities, sorted in descending order.
        alpha_1: the ranked instances must cover at least ``1 - alpha_1`` of
            all fixations.
        alpha_2: minimum relative drop between consecutive densities that is
            treated as a natural cut point.

    Returns:
        ``None`` if the densities never add up to ``1 - alpha_1``.  Otherwise,
        let ``k`` be the first index at which the cumulative density reaches
        ``1 - alpha_1``.  The cut stays at ``k`` when ``k`` is the last index
        or the drop after ``k`` is not below ``alpha_2``.  Otherwise it moves
        to whichever is nearer: the closest index before ``k`` whose drop is
        at least ``alpha_2`` (``-1`` if there is none), or the closest index
        after ``k`` with such a drop (or the last index).  Ties go to the
        later index.
    """
    cumulative = 0
    k = None
    for i, density in enumerate(fix_set):
        cumulative += density
        if cumulative >= 1 - alpha_1:
            k = i
            break
    if k is None:
        return None
    if k == len(fix_set) - 1:
        return k
    if not _relative_gap(fix_set, k) < alpha_2:
        return k

    k_up = k - 1
    while k_up >= 0 and _relative_gap(fix_set, k_up) < alpha_2:
        k_up -= 1
    k_down = k + 1
    while k_down <= len(fix_set) - 2 and _relative_gap(fix_set, k_down) < alpha_2:
        k_down += 1
    return k_up if k - k_up < k_down - k else k_down


# Build one saliency-rank label map per image that has a fixation map.
# Instances claim pixels from smallest to largest area and are scored by the
# share of the image's fixations that land on their claimed pixels.  The top
# instances (see _select_cutoff) are painted with ranks evenly spaced in
# (0, 255], highest density brightest.  Images with fewer than two ranked
# instances produce no output.
idxs = list(sorted(fix_imgs.keys()))
for idx in idxs:
    print(idx)
    coco = cocos[0] if fix_imgs[idx]['type'] == 'train2014' else cocos[1]

    fix_img_name = '{}.png'.format(fix_imgs[idx]['img_name'][:-4])
    fix_img_path = os.path.join(fixationDir, fix_img_name)
    fix_image = cv2.imread(fix_img_path)
    if fix_image is None:
        # cv2.imread returns None instead of raising for unreadable files.
        print('cannot read {}'.format(fix_img_path))
        continue
    fix_image = fix_image[:, :, 0]
    # Pixels equal to 255 are relabelled 1; fixations are counted below as
    # pixels equal to 1.  NOTE(review): assumes the map is strictly 0/255 --
    # other nonzero values would still be added into all_fix.
    fix_image[fix_image == 255] = 1
    all_fix = fix_image.sum().astype(np.float32)
    if all_fix == 0:
        # Densities would be 0/0 (NaN); such an image can never be ranked.
        print('{} has no fixations'.format(idx))
        continue

    coco_image = coco.imgs[idx]
    coco_imageSize = (coco_image['height'], coco_image['width'])
    labelMap = np.zeros(coco_imageSize)
    # check_mask[p] holds (annotation index + 1) of the instance owning pixel p.
    check_mask = np.zeros(coco_imageSize)
    # Per-image annotation lookup instead of scanning every annotation in the
    # split; the order matches, so the stable area sort gives the same result.
    imgAnnots = coco.loadAnns(coco.getAnnIds(imgIds=idx))
    imgAnnots = sorted(imgAnnots, key=lambda x: x['area'])
    instances = {}
    for a, annot in enumerate(imgAnnots):
        label = a + 1
        # Widen the 0/1 mask so labels above 255 cannot overflow a narrow
        # (uint8) mask dtype.
        mask = coco.annToMask(annot).astype(np.int32)
        mask[mask == 1] = label
        # covered_region: labels of smaller, already-placed instances that lie
        # under this instance's mask (0 elsewhere).
        all_overlay_mask = mask + check_mask
        all_overlay_mask[all_overlay_mask < label] = 0
        covered_region = all_overlay_mask - label
        covered_region[covered_region < 0] = 0
        # This instance only owns pixels not already owned by a smaller one.
        mask[check_mask != 0] = 0
        labelMask = mask == label
        check_mask[labelMask] = label
        instances[a] = {
            'fix_density': (fix_image[labelMask] == 1).sum() / all_fix,
            'labelMask': labelMask,
            'covered_region': covered_region,
        }
    if len(instances) == 1:
        print('{} has only one instance'.format(idx))
        continue
    # Highest fixation density first; entries are (annotation index, info).
    instances = sorted(instances.items(), key=lambda x: x[1]['fix_density'], reverse=True)
    instances_fix_set = [instance[1]['fix_density'] for instance in instances]

    k_new = _select_cutoff(instances_fix_set, alpha_1, alpha_2)
    # At least two instances must be ranked (cut index >= 1).
    if k_new is None or k_new <= 0 or k_new >= len(instances_fix_set):
        continue

    # Labels (annotation index + 1) of instances that fall below the cut.
    deleted_labels = [instance[0] + 1 for instance in instances[k_new+1:]]

    for i in range(k_new + 1):
        rank = int(255. / (k_new + 1) * (k_new + 1 - i))
        info = instances[i][1]
        # Pixels of this instance that had been taken by a smaller instance
        # that is now dropped are painted with this instance's rank.
        covered_mask = np.isin(info['covered_region'], deleted_labels)
        labelMap[info['labelMask']] = rank
        labelMap[covered_mask] = rank
    out_dir = os.path.join(outputDir, fix_img_name)
    cv2.imwrite(out_dir, labelMap.astype(np.uint8))

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值